From 21df5bfbc5587db504f3c8ea67efabe95beb4637 Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Mon, 28 Aug 2023 23:48:38 +0200 Subject: [PATCH] [NFC] Replace miopen::ProblemDescription with conv::ProblemDescription, part 3 (#2303) --- src/conv/heuristics/ai_heuristics.cpp | 70 +++--- src/conv/invokers/impl_gemm.cpp | 2 +- src/conv/invokers/impl_gemm_dynamic.cpp | 230 +++++++++--------- src/conv/invokers/mlir_impl_gemm.cpp | 58 ++--- src/conv/problem_description.cpp | 27 +- src/conv/solver_finders.cpp | 8 +- src/convolution.cpp | 5 +- src/fusion.cpp | 6 +- .../miopen/conv/compiled_in_parameters.hpp | 18 +- .../conv/invokers/impl_gemm_dynamic.hpp | 71 +++--- .../miopen/conv/problem_description.hpp | 117 +++++---- src/include/miopen/conv/solver_finders.hpp | 7 +- src/include/miopen/fusion/context.hpp | 2 +- .../miopen/fusion/problem_description.hpp | 2 +- src/include/miopen/mlo_internal.hpp | 3 - src/include/miopen/problem_description.hpp | 136 ++--------- src/include/miopen/solver.hpp | 60 +++-- .../miopen/solver/implicitgemm_util.hpp | 32 +-- .../problem_description_interpreter.hpp | 62 ++--- src/include/miopen/tensor.hpp | 2 +- src/mlo_dir_conv.cpp | 6 - src/problem_description.cpp | 27 +- src/solver/conv_MP_bidirectional_winograd.cpp | 30 +-- src/solver/conv_asm_1x1u.cpp | 121 ++++----- src/solver/conv_asm_1x1u_bias_activ_fused.cpp | 10 +- src/solver/conv_asm_1x1u_stride2.cpp | 128 +++++----- src/solver/conv_asm_3x3u.cpp | 58 ++--- src/solver/conv_asm_5x10u2v2b1.cpp | 40 +-- src/solver/conv_asm_5x10u2v2f1.cpp | 46 ++-- .../conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp | 32 +-- src/solver/conv_asm_dir_BwdWrW1x1.cpp | 86 +++---- src/solver/conv_asm_dir_BwdWrW3x3.cpp | 133 +++++----- ...onv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp | 26 +- src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp | 26 +- .../conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp | 134 +++++----- src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp | 26 +- .../conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp | 58 ++--- .../conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp | 120 ++++----- .../conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp | 141 ++++++----- .../conv_asm_implicit_gemm_v4r1_dynamic.cpp | 10 +- ...m_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp | 88 ++++--- ...onv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp | 99 ++++---- src/solver/conv_bin_wino3x3U.cpp | 26 +- src/solver/conv_bin_winoRxS.cpp | 36 +-- src/solver/conv_bin_winoRxS_fused.cpp | 36 +-- .../conv_ck_igemm_fwd_bias_activ_fused.cpp | 34 +-- .../conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp | 2 +- src/solver/conv_direct_naive_conv_bwd.cpp | 36 +-- src/solver/conv_direct_naive_conv_fwd.cpp | 24 +- src/solver/conv_direct_naive_conv_wrw.cpp | 36 +-- ...ip_implicit_gemm_3d_grouped_fwd_xdlops.cpp | 16 +- ...conv_hip_implicit_gemm_bwd_data_xdlops.cpp | 18 +- .../conv_hip_implicit_gemm_bwd_v1r1.cpp | 2 +- ...conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp | 5 +- .../conv_hip_implicit_gemm_bwd_v4r1.cpp | 2 +- ...conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp | 5 +- .../conv_hip_implicit_gemm_fwd_v4r1.cpp | 80 +++--- .../conv_hip_implicit_gemm_fwd_v4r4.cpp | 2 +- ...conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp | 5 +- ...licit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp | 17 +- ...conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp | 5 +- .../conv_hip_implicit_gemm_fwd_xdlops.cpp | 18 +- ...v_hip_implicit_gemm_grouped_fwd_xdlops.cpp | 16 +- .../conv_hip_implicit_gemm_wrw_v4r4.cpp | 2 +- ...conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp | 7 +- ...licit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp | 7 +- src/solver/conv_mlir_igemm_bwd.cpp | 2 +- src/solver/conv_mlir_igemm_bwd_xdlops.cpp | 2 +- src/solver/conv_mlir_igemm_fwd.cpp | 2 +- src/solver/conv_mlir_igemm_fwd_xdlops.cpp | 2 +- src/solver/conv_mlir_igemm_wrw.cpp | 2 +- src/solver/conv_mlir_igemm_wrw_xdlops.cpp | 2 +- src/solver/conv_multipass_wino3x3WrW.cpp | 42 ++-- src/solver/conv_ocl_dir2D11x11.cpp | 87 +++---- src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp | 102 ++++---- src/solver/conv_ocl_dir2D_bwdWrW_2.cpp | 138 +++++------ src/solver/conv_ocl_dir2D_bwdWrW_53.cpp | 211 ++++++++-------- src/solver/conv_ocl_dir2Dfwd.cpp | 130 +++++----- src/solver/conv_ocl_dir2Dfwd1x1.cpp | 109 +++++---- .../conv_ocl_dir2Dfwd_exhaustive_search.cpp | 60 ++--- src/solver/conv_ocl_dir2Dfwd_fused.cpp | 2 +- src/solver/conv_ocl_dir2Dfwdgen.cpp | 112 ++++----- src/solver/conv_winoRxS.cpp | 74 +++--- src/solver/conv_winoRxS_fused.cpp | 24 +- src/solver/conv_wino_fury_RxS.cpp | 8 +- src/solver/fft.cpp | 28 +-- src/solver/pooling/forwardNaive.cpp | 2 +- test/conv_common.hpp | 10 +- test/solver.cpp | 2 +- 89 files changed, 1904 insertions(+), 2049 deletions(-) diff --git a/src/conv/heuristics/ai_heuristics.cpp b/src/conv/heuristics/ai_heuristics.cpp index c7131552c5..50cd495de6 100644 --- a/src/conv/heuristics/ai_heuristics.cpp +++ b/src/conv/heuristics/ai_heuristics.cpp @@ -153,43 +153,42 @@ class Gfx908Model : public Model const ConvolutionContext& ctx) const override { // check if problem is of the kind TunaNet was trained to handle - if(!problem.conv_problem.Is2d()) + if(!problem.Is2d()) { MIOPEN_LOG_I2("TunaNet Inapplicable: Problem not 2D"); return false; } - if(problem.conv_problem.GetGroupCount() != 1) + if(problem.GetGroupCount() != 1) { MIOPEN_LOG_I2("TunaNet Inapplicable: Group count not 1"); return false; } - if(problem.conv_problem.GetInLayout() != "NCHW" && - problem.conv_problem.GetInLayout() != "NCDHW") + if(problem.GetInLayout() != "NCHW" && problem.GetInLayout() != "NCDHW") { MIOPEN_LOG_I2("TunaNet Inapplicable: Layout not supported"); return false; } - if(problem.conv_problem.GetWeightsHeight() != problem.conv_problem.GetWeightsWidth()) + if(problem.GetWeightsHeight_() != problem.GetWeightsWidth_()) { MIOPEN_LOG_I2("TunaNet Inapplicable: Filters must be square (fil_h == fil_w)"); return false; } - if(problem.conv_problem.GetPadH() != problem.conv_problem.GetPadW()) + if(problem.GetPadH() != problem.GetPadW()) { MIOPEN_LOG_I2("TunaNet Inapplicable: Padding must be equal along all axes"); return false; } - if(problem.conv_problem.GetKernelStrideH() != problem.conv_problem.GetKernelStrideW()) + if(problem.GetKernelStrideH() != problem.GetKernelStrideW()) { MIOPEN_LOG_I2("TunaNet Inapplicable: Stride must be equal along all axes"); return false; } - if(problem.conv_problem.GetDilationH() != 1 || problem.conv_problem.GetDilationW() != 1) + if(problem.GetDilationH() != 1 || problem.GetDilationW() != 1) { MIOPEN_LOG_I2("TunaNet Inapplicable: Dilation must be 1"); return false; } - const auto& data_type = problem.conv_problem.GetInDataType(); + const auto data_type = problem.GetInDataType(); if(data_type != miopenFloat && data_type != miopenHalf && data_type != miopenBFloat16) { MIOPEN_LOG_I2("TunaNet Inapplicable: Unsupported data type"); @@ -219,37 +218,34 @@ class Gfx908Model : public Model protected: std::vector ToFeatures(const ProblemDescription& problem) const override { - const auto& conv_problem = problem.conv_problem; - const bool isFwd = conv_problem.GetDirection() == conv::Direction::Forward; + const bool isFwd = problem.GetDirection() == conv::Direction::Forward; std::vector features = { - static_cast(isFwd ? conv_problem.GetInChannels() - : conv_problem.GetOutChannels()), - static_cast(isFwd ? conv_problem.GetInDepth() : conv_problem.GetOutDepth()), - static_cast(isFwd ? conv_problem.GetInHeight() : conv_problem.GetOutHeight()), - static_cast(isFwd ? conv_problem.GetInWidth() : conv_problem.GetOutWidth()), - static_cast(conv_problem.GetWeightsDepth()), - static_cast(conv_problem.GetWeightsHeight()), - static_cast(conv_problem.GetWeightsWidth()), - static_cast(isFwd ? conv_problem.GetOutChannels() - : conv_problem.GetInChannels()), - static_cast(isFwd ? conv_problem.GetOutDepth() : conv_problem.GetInDepth()), - static_cast(isFwd ? conv_problem.GetOutHeight() : conv_problem.GetInHeight()), - static_cast(isFwd ? conv_problem.GetOutWidth() : conv_problem.GetInWidth()), - static_cast(conv_problem.GetOutBatchSize()), + static_cast(isFwd ? problem.GetInChannels_() : problem.GetOutChannels_()), + static_cast(isFwd ? problem.GetInDepth_() : problem.GetOutDepth_()), + static_cast(isFwd ? problem.GetInHeight_() : problem.GetOutHeight_()), + static_cast(isFwd ? problem.GetInWidth_() : problem.GetOutWidth_()), + static_cast(problem.GetWeightsDepth_()), + static_cast(problem.GetWeightsHeight_()), + static_cast(problem.GetWeightsWidth_()), + static_cast(isFwd ? problem.GetOutChannels_() : problem.GetInChannels_()), + static_cast(isFwd ? problem.GetOutDepth_() : problem.GetInDepth_()), + static_cast(isFwd ? problem.GetOutHeight_() : problem.GetInHeight_()), + static_cast(isFwd ? problem.GetOutWidth_() : problem.GetInWidth_()), + static_cast(problem.GetOutBatchSize_()), static_cast(1), // TunaNet was trained on a dataset of 2D // problems where PadD was incorrectly set to 1 - static_cast(conv_problem.GetPadH()), - static_cast(conv_problem.GetPadW()), + static_cast(problem.GetPadH()), + static_cast(problem.GetPadW()), static_cast(1), // TunaNet was trained on a dataset of 2D // problems where StrideD was incorrectly set to 1 - static_cast(conv_problem.GetKernelStrideH()), - static_cast(conv_problem.GetKernelStrideW()), - static_cast(conv_problem.GetDilationH()), - static_cast(conv_problem.GetDilationW()), - static_cast(metadata.EncodeLayout(conv_problem.GetInLayout())), - static_cast(metadata.EncodePrecision(conv_problem.GetInDataType())), - static_cast(metadata.EncodeDirection(conv_problem.GetDirection())), - static_cast(conv_problem.GetGroupCount())}; + static_cast(problem.GetKernelStrideH()), + static_cast(problem.GetKernelStrideW()), + static_cast(problem.GetDilationH()), + static_cast(problem.GetDilationW()), + static_cast(metadata.EncodeLayout(problem.GetInLayout())), + static_cast(metadata.EncodePrecision(problem.GetInDataType())), + static_cast(metadata.EncodeDirection(problem.GetDirection())), + static_cast(problem.GetGroupCount())}; // normalize for(size_t i = 0; i < features.size(); ++i) @@ -271,7 +267,7 @@ std::vector PredictSolver(const ProblemDescription& problem, std::string est_name = ":memory:" + device; auto& db = AnyRamDb::GetCached(est_name); - auto db_res = db.FindRecord(problem.conv_problem); + auto db_res = db.FindRecord(static_cast(problem)); if(db_res) { MIOPEN_LOG_I2("Cached heuristic result found"); @@ -320,7 +316,7 @@ std::vector PredictSolver(const ProblemDescription& problem, sol.push_back(sol_id.Value()); any_sol.push_back(sol_id.Value()); } - db.StoreRecord(problem.conv_problem, any_sol); + db.StoreRecord(static_cast(problem), any_sol); if(miopen::IsLogging(LoggingLevel::Info2)) { std::stringstream ss; diff --git a/src/conv/invokers/impl_gemm.cpp b/src/conv/invokers/impl_gemm.cpp index 2133eb3256..649e153491 100644 --- a/src/conv/invokers/impl_gemm.cpp +++ b/src/conv/invokers/impl_gemm.cpp @@ -27,7 +27,7 @@ InvokerFactory MakeImplGemmDataInvokerFactory(const miopen::ProblemDescription& if(problem.direction.IsBackwardWrW()) MIOPEN_THROW("MakeImplGemmDataInvokerFactory shouldn't be used for WrW invokers."); - const auto& conv = problem.conv_problem.GetConv(); + const auto& conv = problem.GetConv(); const auto& lowp_quant = conv.lowp_quant; return [conv, lowp_quant](const std::vector& kernels) { diff --git a/src/conv/invokers/impl_gemm_dynamic.cpp b/src/conv/invokers/impl_gemm_dynamic.cpp index 456bdeccb4..01e931dd69 100644 --- a/src/conv/invokers/impl_gemm_dynamic.cpp +++ b/src/conv/invokers/impl_gemm_dynamic.cpp @@ -32,7 +32,7 @@ static inline uint32_t igemm_find_tile_size_with_upper_bound( } static float CallImplGemmDynamicForward1x1(const miopen::Handle& handle, - const ProblemDescription& conv_problem, + const ProblemDescription& problem, ConstData_t src, Data_t dst, ConstData_t wei, @@ -44,19 +44,19 @@ static float CallImplGemmDynamicForward1x1(const miopen::Handle& handle, MIOPEN_LOG_I(kernel.GetName()); // clang-format off - int hi = conv_problem.GetInHeight(); - int wi = conv_problem.GetInWidth(); - int n = conv_problem.GetInBatchSize(); - int k = conv_problem.GetOutChannels(); - int c = conv_problem.GetInChannels(); - int ho = conv_problem.GetOutHeight(); - int wo = conv_problem.GetOutWidth(); - int stride_h = conv_problem.GetKernelStrideH(); - int stride_w = conv_problem.GetKernelStrideW(); - int dilation_h = conv_problem.GetDilationH(); - int dilation_w = conv_problem.GetDilationW(); - int pad_h = conv_problem.GetPadH(); - int pad_w = conv_problem.GetPadW(); + int hi = problem.GetInHeight_(); + int wi = problem.GetInWidth_(); + int n = problem.GetInBatchSize_(); + int k = problem.GetOutChannels_(); + int c = problem.GetInChannels_(); + int ho = problem.GetOutHeight_(); + int wo = problem.GetOutWidth_(); + int stride_h = problem.GetKernelStrideH(); + int stride_w = problem.GetKernelStrideW(); + int dilation_h = problem.GetDilationH(); + int dilation_w = problem.GetDilationW(); + int pad_h = problem.GetPadH(); + int pad_w = problem.GetPadW(); int gap_0 = 0; // clang-format on @@ -89,8 +89,7 @@ static float CallImplGemmDynamicForward1x1(const miopen::Handle& handle, InvokerFactory MakeImplGemmDynamicForward1x1InvokerFactory(const miopen::ProblemDescription& problem) { - const auto& conv_problem = problem.conv_problem; - return [conv_problem](const std::vector& kernels) { + return [problem](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { decltype(auto) data_ctx = primitive_parameters.CastTo(); const auto& tensors = data_ctx.tensors; @@ -103,7 +102,7 @@ MakeImplGemmDynamicForward1x1InvokerFactory(const miopen::ProblemDescription& pr [&](const Kernel& k) { return handle.Run(k); }); float elapsed = 0; elapsed = CallImplGemmDynamicForward1x1( - handle, conv_problem, tensors.in, tensors.out, tensors.w, ks); + handle, problem, tensors.in, tensors.out, tensors.w, ks); if(handle.IsProfilingEnabled()) { handle.ResetKernelTime(); @@ -118,22 +117,21 @@ InvokerFactory MakeImplGemmDynamicBackwardDataInvokerFactory(const miopen::ProblemDescription& problem, const int& cfg) { - const auto& conv_problem = problem.conv_problem; - int hi = conv_problem.GetOutHeight(); - int wi = conv_problem.GetOutWidth(); - int n = conv_problem.GetInBatchSize(); - int k = conv_problem.GetInChannels(); - int c = conv_problem.GetOutChannels(); - int ho = conv_problem.GetInHeight(); - int wo = conv_problem.GetInWidth(); - int stride_h = conv_problem.GetInHeight() > 1 ? conv_problem.GetKernelStrideH() : 1; - int stride_w = conv_problem.GetInWidth() > 1 ? conv_problem.GetKernelStrideW() : 1; - int dilation_h = conv_problem.GetWeightsHeight() > 1 ? conv_problem.GetDilationH() : 1; - int dilation_w = conv_problem.GetWeightsWidth() > 1 ? conv_problem.GetDilationW() : 1; - int pad_h = conv_problem.GetPadH(); - int pad_w = conv_problem.GetPadW(); - int y = conv_problem.GetWeightsHeight(); - int x = conv_problem.GetWeightsWidth(); + int hi = problem.GetOutHeight_(); + int wi = problem.GetOutWidth_(); + int n = problem.GetInBatchSize_(); + int k = problem.GetInChannels_(); + int c = problem.GetOutChannels_(); + int ho = problem.GetInHeight_(); + int wo = problem.GetInWidth_(); + int stride_h = problem.GetInHeight_() > 1 ? problem.GetKernelStrideH() : 1; + int stride_w = problem.GetInWidth_() > 1 ? problem.GetKernelStrideW() : 1; + int dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + int dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; + int pad_h = problem.GetPadH(); + int pad_w = problem.GetPadW(); + int y = problem.GetWeightsHeight_(); + int x = problem.GetWeightsWidth_(); int gcd_stride_dilation_h = solver::gcd(stride_h, dilation_h); int gcd_stride_dilation_w = solver::gcd(stride_w, dilation_w); @@ -256,23 +254,22 @@ InvokerFactory MakeImplGemmDynamicBackwardDataInvokerFactory( const miopen::ProblemDescription& problem, const solver::TunableImplicitGemmGTCDynamic_t& cfg) { - const auto& conv_problem = problem.conv_problem; - int hi = conv_problem.GetOutHeight(); - int wi = conv_problem.GetOutWidth(); - int n = conv_problem.GetInBatchSize(); - int k = conv_problem.GetInChannels(); - int c = conv_problem.GetOutChannels(); - int ho = conv_problem.GetInHeight(); - int wo = conv_problem.GetInWidth(); - int stride_h = conv_problem.GetOutHeight() > 1 ? conv_problem.GetKernelStrideH() : 1; - int stride_w = conv_problem.GetOutWidth() > 1 ? conv_problem.GetKernelStrideW() : 1; - int dilation_h = conv_problem.GetWeightsHeight() > 1 ? conv_problem.GetDilationH() : 1; - int dilation_w = conv_problem.GetWeightsWidth() > 1 ? conv_problem.GetDilationW() : 1; - int pad_h = conv_problem.GetPadH(); - int pad_w = conv_problem.GetPadW(); - int y = conv_problem.GetWeightsHeight(); - int x = conv_problem.GetWeightsWidth(); - int group = conv_problem.GetGroupCount(); + int hi = problem.GetOutHeight_(); + int wi = problem.GetOutWidth_(); + int n = problem.GetInBatchSize_(); + int k = problem.GetInChannels_(); + int c = problem.GetOutChannels_(); + int ho = problem.GetInHeight_(); + int wo = problem.GetInWidth_(); + int stride_h = problem.GetOutHeight_() > 1 ? problem.GetKernelStrideH() : 1; + int stride_w = problem.GetOutWidth_() > 1 ? problem.GetKernelStrideW() : 1; + int dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + int dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; + int pad_h = problem.GetPadH(); + int pad_w = problem.GetPadW(); + int y = problem.GetWeightsHeight_(); + int x = problem.GetWeightsWidth_(); + int group = problem.GetGroupCount(); int gcd_stride_dilation_h = solver::gcd(stride_h, dilation_h); int gcd_stride_dilation_w = solver::gcd(stride_w, dilation_w); @@ -445,26 +442,25 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory( const miopen::ProblemDescription& problem, const solver::PerformanceConfigAsmImplicitGemmGTCFwdXdlopsNHWC& config) { - const auto& conv_problem = problem.conv_problem; - int hi = conv_problem.GetInHeight(); - int wi = conv_problem.GetInWidth(); - int n = conv_problem.GetInBatchSize(); - int k = conv_problem.GetOutChannels(); - int c = conv_problem.GetInChannels(); - int ho = conv_problem.GetOutHeight(); - int wo = conv_problem.GetOutWidth(); - int stride_h = conv_problem.GetKernelStrideH(); - int stride_w = conv_problem.GetKernelStrideW(); - int dilation_h = conv_problem.GetDilationH(); - int dilation_w = conv_problem.GetDilationW(); - int pad_h = conv_problem.GetPadH(); - int pad_w = conv_problem.GetPadW(); - int y = conv_problem.GetWeightsHeight(); - int x = conv_problem.GetWeightsWidth(); - int group = conv_problem.GetGroupCount(); - int c_karg = c / group; - int y_karg = y; - int x_karg = x; + int hi = problem.GetInHeight_(); + int wi = problem.GetInWidth_(); + int n = problem.GetInBatchSize_(); + int k = problem.GetOutChannels_(); + int c = problem.GetInChannels_(); + int ho = problem.GetOutHeight_(); + int wo = problem.GetOutWidth_(); + int stride_h = problem.GetKernelStrideH(); + int stride_w = problem.GetKernelStrideW(); + int dilation_h = problem.GetDilationH(); + int dilation_w = problem.GetDilationW(); + int pad_h = problem.GetPadH(); + int pad_w = problem.GetPadW(); + int y = problem.GetWeightsHeight_(); + int x = problem.GetWeightsWidth_(); + int group = problem.GetGroupCount(); + int c_karg = c / group; + int y_karg = y; + int x_karg = x; int splits_4G = solver::igemm_split_batch_size( hi, wi, ho, wo, n, k, c, miopen::GetTypeSize(problem.GetInDataType())); @@ -539,14 +535,14 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory( std::vector> opArgsTrans; - const auto lowp_quant = problem.conv_problem.GetConv().lowp_quant; + const auto lowp_quant = problem.GetConv().lowp_quant; const auto isGfx90aFp16altSupport = - (ctx.GetStream().GetDeviceName() == "gfx90a") && conv_problem.IsFp16(); + (ctx.GetStream().GetDeviceName() == "gfx90a") && problem.IsFp16(); const bool need_cast = [&]() { - if(problem.conv_problem.GetOut().GetType() == miopenHalf) + if(problem.GetOut().GetType() == miopenHalf) return use_fp32_global_split_on_fp16; - if(problem.conv_problem.GetOut().GetType() == miopenBFloat16) + if(problem.GetOut().GetType() == miopenBFloat16) return need_set_zero; return false; }(); @@ -619,9 +615,8 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory( const int kID_trans_start = isGfx90aFp16altSupport ? 2 : 1; - const TensorDescriptor cast_desc(miopenFloat, - problem.conv_problem.GetOut().GetLengths(), - problem.conv_problem.GetOut().GetStrides()); + const TensorDescriptor cast_desc( + miopenFloat, problem.GetOut().GetLengths(), problem.GetOut().GetStrides()); auto null_buf = shared{}; return [=](const std::vector& kernels) mutable { @@ -740,23 +735,22 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory( const miopen::ProblemDescription& problem, const solver::PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC& config) { - const auto& conv_problem = problem.conv_problem; - int hi = conv_problem.GetOutHeight(); - int wi = conv_problem.GetOutWidth(); - int n = conv_problem.GetInBatchSize(); - int k = conv_problem.GetInChannels(); - int c = conv_problem.GetOutChannels(); - int ho = conv_problem.GetInHeight(); - int wo = conv_problem.GetInWidth(); - int stride_h = conv_problem.GetOutHeight() > 1 ? conv_problem.GetKernelStrideH() : 1; - int stride_w = conv_problem.GetOutWidth() > 1 ? conv_problem.GetKernelStrideW() : 1; - int dilation_h = conv_problem.GetWeightsHeight() > 1 ? conv_problem.GetDilationH() : 1; - int dilation_w = conv_problem.GetWeightsWidth() > 1 ? conv_problem.GetDilationW() : 1; - int pad_h = conv_problem.GetPadH(); - int pad_w = conv_problem.GetPadW(); - int y = conv_problem.GetWeightsHeight(); - int x = conv_problem.GetWeightsWidth(); - int group = conv_problem.GetGroupCount(); + int hi = problem.GetOutHeight_(); + int wi = problem.GetOutWidth_(); + int n = problem.GetInBatchSize_(); + int k = problem.GetInChannels_(); + int c = problem.GetOutChannels_(); + int ho = problem.GetInHeight_(); + int wo = problem.GetInWidth_(); + int stride_h = problem.GetOutHeight_() > 1 ? problem.GetKernelStrideH() : 1; + int stride_w = problem.GetOutWidth_() > 1 ? problem.GetKernelStrideW() : 1; + int dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + int dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; + int pad_h = problem.GetPadH(); + int pad_w = problem.GetPadW(); + int y = problem.GetWeightsHeight_(); + int x = problem.GetWeightsWidth_(); + int group = problem.GetGroupCount(); int gcd_stride_dilation_h = solver::gcd(stride_h, dilation_h); int gcd_stride_dilation_w = solver::gcd(stride_w, dilation_w); @@ -859,13 +853,13 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory( std::vector> opArgsTrans; - const auto lowp_quant = problem.conv_problem.GetConv().lowp_quant; + const auto lowp_quant = problem.GetConv().lowp_quant; const auto isGfx90aFp16altSupport = - (ctx.GetStream().GetDeviceName() == "gfx90a") && conv_problem.IsFp16(); + (ctx.GetStream().GetDeviceName() == "gfx90a") && problem.IsFp16(); const bool need_cast = [&]() { - if(problem.conv_problem.GetOut().GetType() == miopenHalf) + if(problem.GetOut().GetType() == miopenHalf) return use_fp32_global_split_on_fp16; - if(problem.conv_problem.GetOut().GetType() == miopenBFloat16) + if(problem.GetOut().GetType() == miopenBFloat16) return need_set_zero; return false; }(); @@ -938,9 +932,8 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory( const int kID_trans_start = isGfx90aFp16altSupport ? 2 : 1; - const TensorDescriptor cast_desc(miopenFloat, - problem.conv_problem.GetOut().GetLengths(), - problem.conv_problem.GetOut().GetStrides()); + const TensorDescriptor cast_desc( + miopenFloat, problem.GetOut().GetLengths(), problem.GetOut().GetStrides()); auto null_buf = shared{}; return [=](const std::vector& kernels) mutable { @@ -1057,24 +1050,23 @@ InvokerFactory MakeImplGemmDynamicForwardDlopsNCHWCInvokerFactory( const miopen::ProblemDescription& problem, const solver::PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC& config) { - const auto& conv_problem = problem.conv_problem; - int hi = conv_problem.GetInHeight(); - int wi = conv_problem.GetInWidth(); - int n = conv_problem.GetInBatchSize(); - int k = conv_problem.GetOutChannels() * config.vector_c; - int c = conv_problem.GetInChannels(); - int ks = 1; - int ho = conv_problem.GetOutHeight(); - int wo = conv_problem.GetOutWidth(); - int stride_h = conv_problem.GetKernelStrideH(); - int stride_w = conv_problem.GetKernelStrideW(); - int dilation_h = conv_problem.GetDilationH(); - int dilation_w = conv_problem.GetDilationW(); - int pad_h = conv_problem.GetPadH(); - int pad_w = conv_problem.GetPadW(); - int y = conv_problem.GetWeightsHeight(); - int x = conv_problem.GetWeightsWidth(); - int group = conv_problem.GetGroupCount(); + int hi = problem.GetInHeight_(); + int wi = problem.GetInWidth_(); + int n = problem.GetInBatchSize_(); + int k = problem.GetOutChannels_() * config.vector_c; + int c = problem.GetInChannels_(); + int ks = 1; + int ho = problem.GetOutHeight_(); + int wo = problem.GetOutWidth_(); + int stride_h = problem.GetKernelStrideH(); + int stride_w = problem.GetKernelStrideW(); + int dilation_h = problem.GetDilationH(); + int dilation_w = problem.GetDilationW(); + int pad_h = problem.GetPadH(); + int pad_w = problem.GetPadW(); + int y = problem.GetWeightsHeight_(); + int x = problem.GetWeightsWidth_(); + int group = problem.GetGroupCount(); // Currentlly we do not tile in H/W dimension, using tile H/W as Ho/Wo, Thus Number of Tile // equal to one diff --git a/src/conv/invokers/mlir_impl_gemm.cpp b/src/conv/invokers/mlir_impl_gemm.cpp index fb4b3f6f9b..541975c36f 100644 --- a/src/conv/invokers/mlir_impl_gemm.cpp +++ b/src/conv/invokers/mlir_impl_gemm.cpp @@ -74,7 +74,7 @@ struct MlirConvArgs #endif #if MIIR_BARE_POINTER_ABI -void ComputeMlirDimsStrides(const conv::ProblemDescription& conv_problem, +void ComputeMlirDimsStrides(const conv::ProblemDescription& problem, std::vector& in_dims, std::vector& in_strides, std::vector& weights_dims, @@ -83,7 +83,7 @@ void ComputeMlirDimsStrides(const conv::ProblemDescription& conv_problem, std::vector& out_strides) { // The bare pointer ABI doesn't require this info, so this is a noop. - (void)conv_problem; + (void)problem; (void)in_dims; (void)in_strides; (void)weights_dims; @@ -156,7 +156,7 @@ void InsertGToDimsStrides(const std::string& layout, strides.insert(strides.begin() + index, strides[index] * dims[index + 1]); } -void ComputeMlirDimsStrides(const conv::ProblemDescription& conv_problem, +void ComputeMlirDimsStrides(const conv::ProblemDescription& problem, std::vector& in_dims, std::vector& in_strides, std::vector& weights_dims, @@ -164,13 +164,13 @@ void ComputeMlirDimsStrides(const conv::ProblemDescription& conv_problem, std::vector& out_dims, std::vector& out_strides) { - auto group_count = conv_problem.GetGroupCount(); + auto group_count = problem.GetGroupCount(); TensorDescriptor in; - if(conv_problem.GetDirection() == conv::Direction::Forward) - in = conv_problem.GetIn(); + if(problem.GetDirection() == conv::Direction::Forward) + in = problem.GetIn(); else - in = conv_problem.GetOut(); + in = problem.GetOut(); in_dims = in.GetLengths(); in_strides = in.GetStrides(); @@ -179,7 +179,7 @@ void ComputeMlirDimsStrides(const conv::ProblemDescription& conv_problem, InsertGToDimsStrides(in.GetLayout("NCHW"), 'C', group_count, in_dims, in_strides); // Add a virtual group dimension before output channel. - const TensorDescriptor& weights = conv_problem.GetWeights(); + const TensorDescriptor& weights = problem.GetWeights(); weights_dims = weights.GetLengths(); weights_strides = weights.GetStrides(); PermuteDimsStrides(weights_dims, weights_strides); @@ -187,10 +187,10 @@ void ComputeMlirDimsStrides(const conv::ProblemDescription& conv_problem, weights.GetLayout("NCHW"), 'N', group_count, weights_dims, weights_strides); TensorDescriptor out; - if(conv_problem.GetDirection() == conv::Direction::Forward) - out = conv_problem.GetOut(); + if(problem.GetDirection() == conv::Direction::Forward) + out = problem.GetOut(); else - out = conv_problem.GetIn(); + out = problem.GetIn(); out_dims = out.GetLengths(); out_strides = out.GetStrides(); @@ -295,28 +295,22 @@ InvokerFactory MakeMlirFwdInvokerFactory(const miopen::ProblemDescription& probl std::vector in_dims, in_strides; std::vector weights_dims, weights_strides; std::vector out_dims, out_strides; - ComputeMlirDimsStrides(problem.conv_problem, - in_dims, - in_strides, - weights_dims, - weights_strides, - out_dims, - out_strides); + ComputeMlirDimsStrides( + problem, in_dims, in_strides, weights_dims, weights_strides, out_dims, out_strides); MlirConvArgs args = MakeMlirConvArgs( in_dims, in_strides, weights_dims, weights_strides, out_dims, out_strides, 0); - const auto& conv = problem.conv_problem.GetConv(); + const auto& conv = problem.GetConv(); const auto& lowp_quant = conv.lowp_quant; - const auto& outDesc = problem.conv_problem.GetOut(); + const auto& outDesc = problem.GetOut(); TensorDescriptor outConvDesc = outDesc; // outConvDesc is only functional when this is a int8 convolution. It allows the output type to // be cast to a different than int32_t. This gives the solver a wider applicable range and // mimics the behavior of the gemm solver. bool needs_output_cast = false; - if(problem.conv_problem.GetIn().GetType() == miopenInt8 && - problem.conv_problem.GetWeights().GetType() == miopenInt8 && - problem.conv_problem.GetOut().GetType() != miopenInt32) + if(problem.GetIn().GetType() == miopenInt8 && problem.GetWeights().GetType() == miopenInt8 && + problem.GetOut().GetType() != miopenInt32) { needs_output_cast = true; outConvDesc = TensorDescriptor(miopenInt32, outDesc.GetLengths(), outDesc.GetStrides()); @@ -367,13 +361,8 @@ InvokerFactory MakeMlirBwdInvokerFactory(const miopen::ProblemDescription& probl std::vector in_dims, in_strides; std::vector weights_dims, weights_strides; std::vector out_dims, out_strides; - ComputeMlirDimsStrides(problem.conv_problem, - in_dims, - in_strides, - weights_dims, - weights_strides, - out_dims, - out_strides); + ComputeMlirDimsStrides( + problem, in_dims, in_strides, weights_dims, weights_strides, out_dims, out_strides); MlirConvArgs args = MakeMlirConvArgs( in_dims, in_strides, weights_dims, weights_strides, out_dims, out_strides, 0); @@ -428,13 +417,8 @@ InvokerFactory MakeMlirWrWInvokerFactory(const miopen::ProblemDescription& probl std::vector in_dims, in_strides; std::vector weights_dims, weights_strides; std::vector out_dims, out_strides; - ComputeMlirDimsStrides(problem.conv_problem, - in_dims, - in_strides, - weights_dims, - weights_strides, - out_dims, - out_strides); + ComputeMlirDimsStrides( + problem, in_dims, in_strides, weights_dims, weights_strides, out_dims, out_strides); MlirConvArgs args = MakeMlirConvArgs( in_dims, in_strides, weights_dims, weights_strides, out_dims, out_strides, workspace_req); diff --git a/src/conv/problem_description.cpp b/src/conv/problem_description.cpp index b5bfb5e2d8..e44160b4d5 100644 --- a/src/conv/problem_description.cpp +++ b/src/conv/problem_description.cpp @@ -48,7 +48,7 @@ namespace conv { namespace { std::function -PrintDHW(char sep, int spatial_dims, int depth, int height, int width) +PrintDHW(char sep, unsigned spatial_dims, int depth, int height, int width) { return [=](std::ostream& stream) { if(spatial_dims > 2) @@ -107,13 +107,14 @@ void ProblemDescription::BuildConfKey(std::string& conf_key) const { std::ostringstream ss; - ss << GetInChannels(); - ss << 'x' << PrintDHW('x', GetSpatialDims(), GetInDepth(), GetInHeight(), GetInWidth()); + ss << GetInChannels_(); + ss << 'x' << PrintDHW('x', GetSpatialDims(), GetInDepth_(), GetInHeight_(), GetInWidth_()); ss << 'x' - << PrintDHW('x', GetSpatialDims(), GetWeightsDepth(), GetWeightsHeight(), GetWeightsWidth()); - ss << 'x' << GetOutChannels(); - ss << 'x' << PrintDHW('x', GetSpatialDims(), GetOutDepth(), GetOutHeight(), GetOutWidth()); - ss << 'x' << GetInBatchSize(); + << PrintDHW( + 'x', GetSpatialDims(), GetWeightsDepth_(), GetWeightsHeight_(), GetWeightsWidth_()); + ss << 'x' << GetOutChannels_(); + ss << 'x' << PrintDHW('x', GetSpatialDims(), GetOutDepth_(), GetOutHeight_(), GetOutWidth_()); + ss << 'x' << GetInBatchSize_(); if((GetInLayout() == "NCHW" && GetWeightsLayout() == "NCHW" && GetOutLayout() == "NCHW") || (GetInLayout() == "NCDHW" && GetWeightsLayout() == "NCDHW" && GetOutLayout() == "NCDHW")) { @@ -145,12 +146,12 @@ void ProblemDescription::Serialize(std::ostream& stream) const // Problem description with non-default layout // 576-4-4-1x1-192-4-4-8-1x1-2x2-3x3-0-NHWC-NCHW-NCHW-FP32-F // clang-format off - stream << GetInChannels(); - stream << sep << PrintDHW(sep, GetSpatialDims(), GetInDepth(), GetInHeight(), GetInWidth()); - stream << sep << PrintDHW('x', GetSpatialDims(), GetWeightsDepth(), GetWeightsHeight(), GetWeightsWidth()); - stream << sep << GetOutChannels(); - stream << sep << PrintDHW(sep, GetSpatialDims(), GetOutDepth(), GetOutHeight(), GetOutWidth()); - stream << sep << GetInBatchSize(); + stream << GetInChannels_(); + stream << sep << PrintDHW(sep, GetSpatialDims(), GetInDepth_(), GetInHeight_(), GetInWidth_()); + stream << sep << PrintDHW('x', GetSpatialDims(), GetWeightsDepth_(), GetWeightsHeight_(), GetWeightsWidth_()); + stream << sep << GetOutChannels_(); + stream << sep << PrintDHW(sep, GetSpatialDims(), GetOutDepth_(), GetOutHeight_(), GetOutWidth_()); + stream << sep << GetInBatchSize_(); stream << sep << PrintDHW('x', GetSpatialDims(), GetPadD(), GetPadH(), GetPadW()); stream << sep << PrintDHW('x', GetSpatialDims(), GetKernelStrideD(), GetKernelStrideH(), GetKernelStrideW()); stream << sep << PrintDHW('x', GetSpatialDims(), GetDilationD(), GetDilationH(), GetDilationW()); diff --git a/src/conv/solver_finders.cpp b/src/conv/solver_finders.cpp index c4bf130222..adec061aa5 100644 --- a/src/conv/solver_finders.cpp +++ b/src/conv/solver_finders.cpp @@ -63,7 +63,7 @@ class DirectSolverFinder : public SolversFinder const AnyInvokeParams& invoke_ctx, bool /*use_winograd_only*/) const override { - return problem.conv_problem.GetDirection() != conv::Direction::BackwardWeights + return problem.GetDirection() != conv::Direction::BackwardWeights ? FindAllDirectSolutions(ctx, problem, invoke_ctx) : FindAllBwdWrW2DSolutions(ctx, problem, invoke_ctx); } @@ -91,7 +91,7 @@ class ImplicitGemmSolverFinder : public SolversFinder const AnyInvokeParams& invoke_ctx, bool /*use_winograd_only*/) const override { - return problem.conv_problem.GetDirection() != conv::Direction::BackwardWeights + return problem.GetDirection() != conv::Direction::BackwardWeights ? FindAllImplicitGemmSolutions(ctx, problem, invoke_ctx) : FindImplicitGemmWrWAllSolutions(ctx, problem, invoke_ctx); } @@ -175,7 +175,7 @@ class WinogradSolverFinder : public SolversFinder auto ctx_copy = ctx; if(use_winograd_only) ctx_copy.use_dynamic_solutions_only = true; - return problem.conv_problem.GetDirection() != conv::Direction::BackwardWeights + return problem.GetDirection() != conv::Direction::BackwardWeights ? FindAllWinogradSolutions(ctx_copy, problem, invoke_ctx) : FindWinogradWrWAllSolutions(ctx_copy, problem, invoke_ctx); } @@ -288,7 +288,7 @@ void ConvFindCore(const AnyInvokeParams& invoke_ctx, auto solutions = std::map>{}; std::transform( finders.begin(), finders.end(), std::inserter(solutions, solutions.end()), [&](auto&& f) { - return std::make_pair(f->GetAlgorithmName(problem.conv_problem), + return std::make_pair(f->GetAlgorithmName(problem), f->Find(ctx, problem, invoke_ctx, use_winograd_only)); }); diff --git a/src/convolution.cpp b/src/convolution.cpp index 07986350f2..cb6cde5eda 100644 --- a/src/convolution.cpp +++ b/src/convolution.cpp @@ -80,8 +80,7 @@ std::size_t GetWorkSpaceSizeGEMM(const miopen::ConvolutionContext& ctx, { #if MIOPEN_USE_GEMM if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_GEMM{}) || - miopen::any_of(problem.conv_problem.GetConv().GetConvDilations(), - [](auto v) { return v > 1; })) + miopen::any_of(problem.GetConv().GetConvDilations(), [](auto v) { return v > 1; })) return 0; return GetMaxWorkSpaceSize(AllGemmWorkspaceSize(ctx, problem)); @@ -395,7 +394,7 @@ bool ConvolutionDescriptor::IsWinograd3x3SupportedAndFast(const miopen::Convolut return false; // Filter out configs where 3x3 Winograd does not have high WTI. - if(!(problem.GetOutChannels() >= 16 && problem.GetOutChannels() % 2 == 0)) + if(!(problem.GetOutChannels_() >= 16 && problem.GetOutChannels_() % 2 == 0)) return false; return solver::ConvBinWinograd3x3U{}.IsApplicable(ctx, problem); diff --git a/src/fusion.cpp b/src/fusion.cpp index 170e1b2d50..3ec9dd0ec1 100644 --- a/src/fusion.cpp +++ b/src/fusion.cpp @@ -136,7 +136,7 @@ static auto AllocateBuffersAndMakeConvBiasActivFusionInvokeParams( << " , size: " << conv_problem.GetWeightsSize() << " , out addr: " << invoke_bufs[3].get() << " , size: " << conv_problem.GetOutSize()); - const auto gfx90aaltimpl = conv_problem.conv_problem.GetConv().attribute.gfx90aFp16alt.GetFwd(); + const auto gfx90aaltimpl = conv_problem.GetConv().attribute.gfx90aFp16alt.GetFwd(); auto conv_data = std::make_unique(invoke_bufs[2].get()); @@ -153,9 +153,9 @@ static auto AllocateBuffersAndMakeConvBiasActivFusionInvokeParams( params.SetArg(2, std::move(activ_data)); return miopen::fusion::FusionInvokeParams(params, - conv_problem.conv_problem.GetIn(), + conv_problem.GetIn(), invoke_bufs[1].get(), - conv_problem.conv_problem.GetOut(), + conv_problem.GetOut(), invoke_bufs[3].get(), gfx90aaltimpl); } diff --git a/src/include/miopen/conv/compiled_in_parameters.hpp b/src/include/miopen/conv/compiled_in_parameters.hpp index 1693273806..28def48761 100644 --- a/src/include/miopen/conv/compiled_in_parameters.hpp +++ b/src/include/miopen/conv/compiled_in_parameters.hpp @@ -46,11 +46,11 @@ inline void GetCompiledInParameters(const ExecutionContext& ctx, int* const n_groups) { assert(N && C && H && W && K && n_groups); - *N = problem.GetBatchSize(); - *C = problem.GetInChannels(); - *H = problem.GetInHeight(); - *W = problem.GetInWidth(); - *K = problem.GetOutChannels(); + *N = problem.GetBatchSize_(); + *C = problem.GetInChannels_(); + *H = problem.GetInHeight_(); + *W = problem.GetInWidth_(); + *K = problem.GetOutChannels_(); *n_groups = ctx.GetStream().GetMaxComputeUnits(); } @@ -67,8 +67,8 @@ inline void GetCompiledInParameters(const ExecutionContext& ctx, { GetCompiledInParameters(ctx, problem, N, C, H, W, K, n_groups); assert(out_H && out_W); - *out_H = problem.GetOutHeight(); - *out_W = problem.GetOutWidth(); + *out_H = problem.GetOutHeight_(); + *out_W = problem.GetOutWidth_(); } inline void GetCompiledInParameters(const ExecutionContext& ctx, @@ -88,8 +88,8 @@ inline void GetCompiledInParameters(const ExecutionContext& ctx, { GetCompiledInParameters(ctx, problem, N, C, H, W, K, n_groups, out_H, out_W); assert(filter_size_H && filter_size_W && pad_H && pad_W); - *filter_size_H = problem.GetWeightsHeight(); - *filter_size_W = problem.GetWeightsWidth(); + *filter_size_H = problem.GetWeightsHeight_(); + *filter_size_W = problem.GetWeightsWidth_(); *pad_H = problem.direction.IsForward() ? problem.GetPadH() : problem.GetBackwardPadH(); *pad_W = problem.direction.IsForward() ? problem.GetPadW() : problem.GetBackwardPadW(); } diff --git a/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp b/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp index 4ee5ae7b55..e2d329b0a9 100644 --- a/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp +++ b/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp @@ -42,29 +42,29 @@ namespace conv { template inline std::vector -ComputeDynamicIGemmForwardKernelArgs(const ProblemDescription& conv_problem, const T& cfg); +ComputeDynamicIGemmForwardKernelArgs(const ProblemDescription& problem, const T& cfg); template <> inline std::vector -ComputeDynamicIGemmForwardKernelArgs(const ProblemDescription& conv_problem, const int& cfg) +ComputeDynamicIGemmForwardKernelArgs(const ProblemDescription& problem, const int& cfg) { std::vector opArgs; // clang-format off - int hi = conv_problem.GetInHeight(); - int wi = conv_problem.GetInWidth(); - int n = conv_problem.GetInBatchSize(); - int k = conv_problem.GetOutChannels(); - int c = conv_problem.GetInChannels(); - int ho = conv_problem.GetOutHeight(); - int wo = conv_problem.GetOutWidth(); - int stride_h = conv_problem.GetKernelStrideH(); - int stride_w = conv_problem.GetKernelStrideW(); - int dilation_h = conv_problem.GetDilationH(); - int dilation_w = conv_problem.GetDilationW(); - int pad_h = conv_problem.GetPadH(); - int pad_w = conv_problem.GetPadW(); - int y = conv_problem.GetWeightsHeight(); - int x = conv_problem.GetWeightsWidth(); + int hi = problem.GetInHeight_(); + int wi = problem.GetInWidth_(); + int n = problem.GetInBatchSize_(); + int k = problem.GetOutChannels_(); + int c = problem.GetInChannels_(); + int ho = problem.GetOutHeight_(); + int wo = problem.GetOutWidth_(); + int stride_h = problem.GetKernelStrideH(); + int stride_w = problem.GetKernelStrideW(); + int dilation_h = problem.GetDilationH(); + int dilation_w = problem.GetDilationW(); + int pad_h = problem.GetPadH(); + int pad_w = problem.GetPadW(); + int y = problem.GetWeightsHeight_(); + int x = problem.GetWeightsWidth_(); int pack0 = cfg; // clang-format on @@ -94,26 +94,26 @@ ComputeDynamicIGemmForwardKernelArgs(const ProblemDescription& conv_problem template <> inline std::vector ComputeDynamicIGemmForwardKernelArgs( - const ProblemDescription& conv_problem, const solver::TunableImplicitGemmGTCDynamic_t& cfg) + const ProblemDescription& problem, const solver::TunableImplicitGemmGTCDynamic_t& cfg) { std::vector opArgs; // clang-format off - int hi = conv_problem.GetInHeight(); - int wi = conv_problem.GetInWidth(); - int n = conv_problem.GetInBatchSize(); - int k = conv_problem.GetOutChannels(); - int c = conv_problem.GetInChannels(); - int ho = conv_problem.GetOutHeight(); - int wo = conv_problem.GetOutWidth(); - int stride_h = conv_problem.GetKernelStrideH(); - int stride_w = conv_problem.GetKernelStrideW(); - int dilation_h = conv_problem.GetDilationH(); - int dilation_w = conv_problem.GetDilationW(); - int pad_h = conv_problem.GetPadH(); - int pad_w = conv_problem.GetPadW(); - int y = conv_problem.GetWeightsHeight(); - int x = conv_problem.GetWeightsWidth(); - int group = conv_problem.GetGroupCount(); + int hi = problem.GetInHeight_(); + int wi = problem.GetInWidth_(); + int n = problem.GetInBatchSize_(); + int k = problem.GetOutChannels_(); + int c = problem.GetInChannels_(); + int ho = problem.GetOutHeight_(); + int wo = problem.GetOutWidth_(); + int stride_h = problem.GetKernelStrideH(); + int stride_w = problem.GetKernelStrideW(); + int dilation_h = problem.GetDilationH(); + int dilation_w = problem.GetDilationW(); + int pad_h = problem.GetPadH(); + int pad_w = problem.GetPadW(); + int y = problem.GetWeightsHeight_(); + int x = problem.GetWeightsWidth_(); + int group = problem.GetGroupCount(); int pack0 = 0; // clang-format on @@ -188,8 +188,7 @@ template static inline InvokerFactory MakeImplGemmDynamicForwardInvokerFactory(const miopen::ProblemDescription& problem, const T& cfg) { - const auto& conv_problem = problem.conv_problem; - auto opArgs = ComputeDynamicIGemmForwardKernelArgs(conv_problem, cfg); + auto opArgs = ComputeDynamicIGemmForwardKernelArgs(problem, cfg); return [opArgs](const std::vector& kernels) mutable { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) mutable { decltype(auto) data_ctx = primitive_parameters.CastTo(); diff --git a/src/include/miopen/conv/problem_description.hpp b/src/include/miopen/conv/problem_description.hpp index 98af5746ad..c458f3421d 100644 --- a/src/include/miopen/conv/problem_description.hpp +++ b/src/include/miopen/conv/problem_description.hpp @@ -61,7 +61,7 @@ inline std::string GetDataTypeName(miopenDataType_t data_type) } template -constexpr auto GetDHW(int spatial_dims, const std::vector& data) +constexpr auto GetDHW(unsigned spatial_dims, const std::vector& data) { if(spatial_dims == 2) return std::make_tuple(0, data[0], data[1]); @@ -69,19 +69,19 @@ constexpr auto GetDHW(int spatial_dims, const std::vector& data) } template -constexpr TElement GetD3(int spatial_dims, const std::vector& data) +constexpr TElement GetD3(unsigned spatial_dims, const std::vector& data) { return std::get<0>(GetDHW(spatial_dims, data)); } template -constexpr TElement GetH3(int spatial_dims, const std::vector& data) +constexpr TElement GetH3(unsigned spatial_dims, const std::vector& data) { return std::get<1>(GetDHW(spatial_dims, data)); } template -constexpr TElement GetW3(int spatial_dims, const std::vector& data) +constexpr TElement GetW3(unsigned spatial_dims, const std::vector& data) { return std::get<2>(GetDHW(spatial_dims, data)); } @@ -116,31 +116,31 @@ constexpr TElement GetWofCHWN(const std::vector& data) } template -constexpr TElement GetN5(int spatial_dims, const std::vector& data) +constexpr TElement GetN5(unsigned spatial_dims, const std::vector& data) { return std::get<0>(GetNCDHW(spatial_dims, data)); } template -constexpr TElement GetC5(int spatial_dims, const std::vector& data) +constexpr TElement GetC5(unsigned spatial_dims, const std::vector& data) { return std::get<1>(GetNCDHW(spatial_dims, data)); } template -constexpr TElement GetD5(int spatial_dims, const std::vector& data) +constexpr TElement GetD5(unsigned spatial_dims, const std::vector& data) { return std::get<2>(GetNCDHW(spatial_dims, data)); } template -constexpr TElement GetH5(int spatial_dims, const std::vector& data) +constexpr TElement GetH5(unsigned spatial_dims, const std::vector& data) { return std::get<3>(GetNCDHW(spatial_dims, data)); } template -constexpr TElement GetW5(int spatial_dims, const std::vector& data) +constexpr TElement GetW5(unsigned spatial_dims, const std::vector& data) { return std::get<4>(GetNCDHW(spatial_dims, data)); } @@ -175,7 +175,7 @@ struct ProblemDescription : ProblemDescriptionBase } // Conv descriptor getters - std::size_t GetSpatialDims() const { return conv.GetSpatialDimension(); } + unsigned GetSpatialDims() const { return conv.GetSpatialDimension(); } int GetPadD() const { return GetD3(GetSpatialDims(), conv.GetConvPads()); } int GetPadH() const { return GetH3(GetSpatialDims(), conv.GetConvPads()); } int GetPadW() const { return GetW3(GetSpatialDims(), conv.GetConvPads()); } @@ -190,17 +190,17 @@ struct ProblemDescription : ProblemDescriptionBase // In getters miopenDataType_t GetInDataType() const { return in.GetType(); } - std::size_t GetInBatchSize() const { return GetN5(GetSpatialDims(), in.GetLengths()); } - std::size_t GetBatchSize() const { return GetInBatchSize(); } // alias of GetInBatchSize() - std::size_t GetInChannels() const { return GetC5(GetSpatialDims(), in.GetLengths()); } - std::size_t GetInDepth() const { return GetD5(GetSpatialDims(), in.GetLengths()); } - std::size_t GetInHeight() const { return GetH5(GetSpatialDims(), in.GetLengths()); } - std::size_t GetInWidth() const { return GetW5(GetSpatialDims(), in.GetLengths()); } - std::size_t GetInBatchStride() const { return GetN5(GetSpatialDims(), in.GetStrides()); } - std::size_t GetInChannelStride() const { return GetC5(GetSpatialDims(), in.GetStrides()); } - std::size_t GetInStrideD() const { return GetD5(GetSpatialDims(), in.GetStrides()); } - std::size_t GetInStrideH() const { return GetH5(GetSpatialDims(), in.GetStrides()); } - std::size_t GetInStrideW() const { return GetW5(GetSpatialDims(), in.GetStrides()); } + unsigned GetInBatchSize_() const { return GetN5(GetSpatialDims(), in.GetLengths()); } + unsigned GetBatchSize_() const { return GetInBatchSize_(); } // alias of GetInBatchSize_() + unsigned GetInChannels_() const { return GetC5(GetSpatialDims(), in.GetLengths()); } + unsigned GetInDepth_() const { return GetD5(GetSpatialDims(), in.GetLengths()); } + unsigned GetInHeight_() const { return GetH5(GetSpatialDims(), in.GetLengths()); } + unsigned GetInWidth_() const { return GetW5(GetSpatialDims(), in.GetLengths()); } + unsigned GetInBatchStride_() const { return GetN5(GetSpatialDims(), in.GetStrides()); } + unsigned GetInChannelStride_() const { return GetC5(GetSpatialDims(), in.GetStrides()); } + unsigned GetInStrideD_() const { return GetD5(GetSpatialDims(), in.GetStrides()); } + unsigned GetInStrideH_() const { return GetH5(GetSpatialDims(), in.GetStrides()); } + unsigned GetInStrideW_() const { return GetW5(GetSpatialDims(), in.GetStrides()); } std::string GetInLayout() const { return in_layout; } std::string ComputeInLayout() const { @@ -217,22 +217,22 @@ struct ProblemDescription : ProblemDescriptionBase std::size_t GetInSize() const { - return GetInBatchSize() * GetInChannels() * GetInDepth() * GetInHeight() * GetInWidth() * - GetInElementSize(); + return static_cast(GetInBatchSize_()) * GetInChannels_() * GetInDepth_() * + GetInHeight_() * GetInWidth_() * GetInElementSize(); } // Out getters miopenDataType_t GetOutDataType() const { return out.GetType(); } - std::size_t GetOutBatchSize() const { return GetN5(GetSpatialDims(), out.GetLengths()); } - std::size_t GetOutChannels() const { return GetC5(GetSpatialDims(), out.GetLengths()); } - std::size_t GetOutDepth() const { return GetD5(GetSpatialDims(), out.GetLengths()); } - std::size_t GetOutHeight() const { return GetH5(GetSpatialDims(), out.GetLengths()); } - std::size_t GetOutWidth() const { return GetW5(GetSpatialDims(), out.GetLengths()); } - std::size_t GetOutBatchStride() const { return GetN5(GetSpatialDims(), out.GetStrides()); } - std::size_t GetOutChannelStride() const { return GetC5(GetSpatialDims(), out.GetStrides()); } - std::size_t GetOutStrideD() const { return GetD5(GetSpatialDims(), out.GetStrides()); } - std::size_t GetOutStrideH() const { return GetH5(GetSpatialDims(), out.GetStrides()); } - std::size_t GetOutStrideW() const { return GetW5(GetSpatialDims(), out.GetStrides()); } + unsigned GetOutBatchSize_() const { return GetN5(GetSpatialDims(), out.GetLengths()); } + unsigned GetOutChannels_() const { return GetC5(GetSpatialDims(), out.GetLengths()); } + unsigned GetOutDepth_() const { return GetD5(GetSpatialDims(), out.GetLengths()); } + unsigned GetOutHeight_() const { return GetH5(GetSpatialDims(), out.GetLengths()); } + unsigned GetOutWidth_() const { return GetW5(GetSpatialDims(), out.GetLengths()); } + unsigned GetOutBatchStride_() const { return GetN5(GetSpatialDims(), out.GetStrides()); } + unsigned GetOutChannelStride_() const { return GetC5(GetSpatialDims(), out.GetStrides()); } + unsigned GetOutStrideD_() const { return GetD5(GetSpatialDims(), out.GetStrides()); } + unsigned GetOutStrideH_() const { return GetH5(GetSpatialDims(), out.GetStrides()); } + unsigned GetOutStrideW_() const { return GetW5(GetSpatialDims(), out.GetStrides()); } std::string GetOutLayout() const { return out_layout; } std::string ComputeOutLayout() const { @@ -249,33 +249,30 @@ struct ProblemDescription : ProblemDescriptionBase std::size_t GetOutSize() const { - return GetOutBatchSize() * GetOutChannels() * GetOutDepth() * GetOutHeight() * - GetOutWidth() * GetOutElementSize(); + return static_cast(GetOutBatchSize_()) * GetOutChannels_() * GetOutDepth_() * + GetOutHeight_() * GetOutWidth_() * GetOutElementSize(); } // Weights getters miopenDataType_t GetWeightsDataType() const { return weights.GetType(); } - std::size_t GetWeightsDepth() const { return GetD5(GetSpatialDims(), weights.GetLengths()); } - std::size_t GetWeightsHeight() const + unsigned GetWeightsDepth_() const { return GetD5(GetSpatialDims(), weights.GetLengths()); } + unsigned GetWeightsHeight_() const { if(weights.GetLayout_str() == "CHWNc") return GetHofCHWN(weights.GetLengths()); else return GetH5(GetSpatialDims(), weights.GetLengths()); } - std::size_t GetWeightsWidth() const + unsigned GetWeightsWidth_() const { if(weights.GetLayout_str() == "CHWNc") return GetWofCHWN(weights.GetLengths()); else return GetW5(GetSpatialDims(), weights.GetLengths()); } - // std::size_t GetWeightsStrideD() const { return GetD5(GetSpatialDims(), weights.GetStrides()); - // } - // std::size_t GetWeightsStrideH() const { return GetH5(GetSpatialDims(), weights.GetStrides()); - // } - // std::size_t GetWeightsStrideW() const { return GetW5(GetSpatialDims(), weights.GetStrides()); - // } + // unsigned GetWeightsStrideD() const { return GetD5(GetSpatialDims(), weights.GetStrides()); } + // unsigned GetWeightsStrideH() const { return GetH5(GetSpatialDims(), weights.GetStrides()); } + // unsigned GetWeightsStrideW() const { return GetW5(GetSpatialDims(), weights.GetStrides()); } std::string GetWeightsLayout() const { return weights_layout; } std::string ComputeWeightsLayout() const { @@ -292,8 +289,8 @@ struct ProblemDescription : ProblemDescriptionBase std::size_t GetWeightsSize() const { - return GetInChannels() * GetOutChannels() * GetWeightsDepth() * GetWeightsHeight() * - GetWeightsWidth() * GetWeightsElementSize(); + return static_cast(GetInChannels_()) * GetOutChannels_() * GetWeightsDepth_() * + GetWeightsHeight_() * GetWeightsWidth_() * GetWeightsElementSize(); } const TensorDescriptor& GetIn() const { return in; } @@ -308,19 +305,19 @@ struct ProblemDescription : ProblemDescriptionBase std::size_t GetBiasSize() const { - return (GetBias() != 0) ? (GetOutChannels() * GetOutElementSize()) : 0; + return (GetBias() != 0) ? (GetOutChannels_() * GetOutElementSize()) : 0; } - int GetBackwardPadW() const { return static_cast(GetWeightsWidth()) - GetPadW() - 1; } - int GetBackwardPadH() const { return static_cast(GetWeightsHeight()) - GetPadH() - 1; } + int GetBackwardPadW() const { return static_cast(GetWeightsWidth_()) - GetPadW() - 1; } + int GetBackwardPadH() const { return static_cast(GetWeightsHeight_()) - GetPadH() - 1; } bool IsAsymmetricPadH() const { - return conv.paddingMode == miopenPaddingSame && (GetWeightsHeight() % 2) == 0; + return conv.paddingMode == miopenPaddingSame && (GetWeightsHeight_() % 2) == 0; } bool IsAsymmetricPadW() const { - return conv.paddingMode == miopenPaddingSame && (GetWeightsWidth() % 2) == 0; + return conv.paddingMode == miopenPaddingSame && (GetWeightsWidth_() % 2) == 0; } bool Is2d() const { return GetSpatialDims() == 2; } @@ -395,15 +392,15 @@ struct ProblemDescription : ProblemDescriptionBase { // The column names match the driver command line argument names f(self.GetSpatialDims(), "spatial_dim"); - f(self.GetInChannels(), "in_channels"); - f(self.GetInHeight(), "in_h"); - f(self.GetInWidth(), "in_w"); - f(self.GetInDepth(), "in_d"); - f(self.GetWeightsHeight(), "fil_h"); - f(self.GetWeightsWidth(), "fil_w"); - f(self.GetWeightsDepth(), "fil_d"); - f(self.GetOutChannels(), "out_channels"); - f(self.GetBatchSize(), "batchsize"); + f(self.GetInChannels_(), "in_channels"); + f(self.GetInHeight_(), "in_h"); + f(self.GetInWidth_(), "in_w"); + f(self.GetInDepth_(), "in_d"); + f(self.GetWeightsHeight_(), "fil_h"); + f(self.GetWeightsWidth_(), "fil_w"); + f(self.GetWeightsDepth_(), "fil_d"); + f(self.GetOutChannels_(), "out_channels"); + f(self.GetBatchSize_(), "batchsize"); f(self.GetPadH(), "pad_h"); f(self.GetPadW(), "pad_w"); f(self.GetPadD(), "pad_d"); diff --git a/src/include/miopen/conv/solver_finders.hpp b/src/include/miopen/conv/solver_finders.hpp index 7494f214aa..69425f09a7 100644 --- a/src/include/miopen/conv/solver_finders.hpp +++ b/src/include/miopen/conv/solver_finders.hpp @@ -49,16 +49,15 @@ class SolversFinder const AnyInvokeParams& invoke_ctx, bool use_winograd_only) const { - if(!IsEnabled(ctx, problem.conv_problem, use_winograd_only)) + if(!IsEnabled(ctx, problem, use_winograd_only)) { - MIOPEN_LOG_I2("Skipping " << GetAlgorithmName(problem.conv_problem).ToString()); + MIOPEN_LOG_I2("Skipping " << GetAlgorithmName(problem).ToString()); return {}; } try { - MIOPEN_LOG_I2("Starting find for " - << GetAlgorithmName(problem.conv_problem).ToString()); + MIOPEN_LOG_I2("Starting find for " << GetAlgorithmName(problem).ToString()); return FindImpl(ctx, problem, invoke_ctx, use_winograd_only); } catch(Exception& ex) diff --git a/src/include/miopen/fusion/context.hpp b/src/include/miopen/fusion/context.hpp index 4cc948c091..cfa38f36f5 100644 --- a/src/include/miopen/fusion/context.hpp +++ b/src/include/miopen/fusion/context.hpp @@ -35,7 +35,7 @@ struct FusionContext : miopen::ExecutionContext ConvolutionContext GetConvContext(const miopen::ProblemDescription& conv_problem) const { auto ctx = ConvolutionContext{*this}; - conv_problem.conv_problem.SetupFloats(ctx); + conv_problem.SetupFloats(ctx); return ctx; } diff --git a/src/include/miopen/fusion/problem_description.hpp b/src/include/miopen/fusion/problem_description.hpp index 7403d008ae..d3b6689764 100644 --- a/src/include/miopen/fusion/problem_description.hpp +++ b/src/include/miopen/fusion/problem_description.hpp @@ -46,7 +46,7 @@ struct FusionDescription if(op->kind() == miopenFusionOpConvForward) { const auto prob = GetConvProblem(op->GetIdx(), conv::Direction::Forward); - net_config << prob.conv_problem.BuildConfKey().ToString(); + net_config << prob.BuildConfKey().ToString(); } else if(op->kind() == miopenFusionOpBatchNormInference) { diff --git a/src/include/miopen/mlo_internal.hpp b/src/include/miopen/mlo_internal.hpp index 22cc7bade4..c3a00fc3ee 100644 --- a/src/include/miopen/mlo_internal.hpp +++ b/src/include/miopen/mlo_internal.hpp @@ -178,9 +178,6 @@ auto mloConstruct(T& x) -> decltype(x.mloConstruct(), void()) x.mloConstruct(); } -bool IsGemmAplicable(const miopen::ConvolutionContext& ctx, - const miopen::ProblemDescription& problem); - std::vector FindAllGemmSolutions(const miopen::ConvolutionContext& ctx, const miopen::ProblemDescription& problem, diff --git a/src/include/miopen/problem_description.hpp b/src/include/miopen/problem_description.hpp index 52a5a2b74b..bc781d4b1e 100644 --- a/src/include/miopen/problem_description.hpp +++ b/src/include/miopen/problem_description.hpp @@ -38,6 +38,8 @@ #include #include +#define FIN_OLD_PROBLEM_DESCRIPTION_COMPAT 1 + namespace miopen { // Tensor Helper APIs @@ -63,73 +65,8 @@ SetDescFromMLDesc(int spatial_dims, TTo& to, const TensorDescriptor& tensor, con struct ConvolutionDescriptor; // Todo: change all uses in convolution to conv::ProblemDescription and remove this -struct ProblemDescription -#if MIOPEN_ENABLE_SQLITE - : SQLiteSerializable -#endif +struct ProblemDescription : conv::ProblemDescription { - conv::ProblemDescription conv_problem; - - int GetSpatialDims() const { return conv_problem.GetSpatialDims(); } - int GetInChannels() const { return conv_problem.GetInChannels(); } - int GetInHeight() const { return conv_problem.GetInHeight(); } - int GetInWidth() const { return conv_problem.GetInWidth(); } - int GetInDepth() const { return conv_problem.GetInDepth(); } - int GetVectorLength() const { return conv_problem.GetVectorLength(); } - int GetWeightsHeight() const { return conv_problem.GetWeightsHeight(); } - int GetWeightsWidth() const { return conv_problem.GetWeightsWidth(); } - int GetWeightsDepth() const { return conv_problem.GetWeightsDepth(); } - int GetOutChannels() const { return conv_problem.GetOutChannels(); } - int GetOutHeight() const { return conv_problem.GetOutHeight(); } - int GetOutWidth() const { return conv_problem.GetOutWidth(); } - int GetOutDepth() const { return conv_problem.GetOutDepth(); } - int GetBatchSize() const { return conv_problem.GetBatchSize(); } - int GetPadH() const { return conv_problem.GetPadH(); } - int GetPadW() const { return conv_problem.GetPadW(); } - int GetPadD() const { return conv_problem.GetPadD(); } - int GetKernelStrideH() const { return conv_problem.GetKernelStrideH(); } - int GetKernelStrideW() const { return conv_problem.GetKernelStrideW(); } - int GetKernelStrideD() const { return conv_problem.GetKernelStrideD(); } - int GetDilationH() const { return conv_problem.GetDilationH(); } - int GetDilationW() const { return conv_problem.GetDilationW(); } - int GetDilationD() const { return conv_problem.GetDilationD(); } - int GetBias() const { return conv_problem.GetBias(); } - std::string GetInLayout() const { return conv_problem.GetInLayout(); } - std::string GetWeightsLayout() const { return conv_problem.GetWeightsLayout(); } - std::string GetOutLayout() const { return conv_problem.GetOutLayout(); } - miopenDataType_t GetInDataType() const { return conv_problem.GetInDataType(); } - miopenDataType_t GetWeightsDataType() const { return conv_problem.GetWeightsDataType(); } - miopenDataType_t GetOutDataType() const { return conv_problem.GetOutDataType(); } - size_t GetInSize() const { return conv_problem.GetInSize(); } - size_t GetOutSize() const { return conv_problem.GetOutSize(); } - size_t GetWeightsSize() const { return conv_problem.GetWeightsSize(); } - size_t GetBiasSize() const { return conv_problem.GetBiasSize(); } - int GetInStride() const { return conv_problem.GetInStrideH(); } - int GetOutStride() const { return conv_problem.GetOutStrideH(); } - int GetInChannelStride() const { return conv_problem.GetInChannelStride(); } - int GetInBatchStride() const { return conv_problem.GetInBatchStride(); } - int GetOutChannelStride() const { return conv_problem.GetOutChannelStride(); } - int GetOutBatchStride() const { return conv_problem.GetOutBatchStride(); } - int GetGroupCount() const { return conv_problem.GetGroupCount(); } - -#if MIOPEN_ENABLE_SQLITE - static std::string table_name() { return conv::ProblemDescription::table_name(); } -#endif - - bool IsLayoutDefault() const; - - bool IsLayoutNHWC() const; - - bool IsLayoutNCHWc() const; - -#if MIOPEN_ENABLE_SQLITE - template - static void Visit(Self&& self, F f) - { - conv::ProblemDescription::Visit(self, f); - } -#endif - struct Direction { public: @@ -146,43 +83,20 @@ struct ProblemDescription conv::Direction v = conv::Direction::Forward; } direction; - std::string GetDirectionStr() const { return direction.GetStr(); } - - int GetBackwardPadW() const { return conv_problem.GetBackwardPadW(); } - int GetBackwardPadH() const { return conv_problem.GetBackwardPadH(); } - - bool IsAsymmetricPadH() const { return conv_problem.IsAsymmetricPadH(); } - bool IsAsymmetricPadW() const { return conv_problem.IsAsymmetricPadW(); } - - bool Is2d() const { return conv_problem.Is2d(); } - bool Is3d() const { return conv_problem.Is3d(); } - - bool IsFp32() const { return conv_problem.IsFp32(); } - bool IsFp16() const { return conv_problem.IsFp16(); } - bool IsBfp16() const { return conv_problem.IsBfp16(); } - bool IsInt8() const { return conv_problem.IsInt8(); } - - bool IsNCHWc_NCHWc() const { return conv_problem.IsNCHWc_NCHWc(); } - - bool IsNCHWc_CHWNc() const { return conv_problem.IsNCHWc_CHWNc(); } - ProblemDescription() = default; ProblemDescription(conv::ProblemDescription desc); - void Serialize(std::ostream& stream) const; - - friend std::ostream& operator<<(std::ostream& os, const ProblemDescription& obj) +#if FIN_OLD_PROBLEM_DESCRIPTION_COMPAT + struct { - obj.Serialize(os); - return os; - } - - void BuildConfKey(std::string& conf_key) const; + void SetupFloats(ExecutionContext& ctx) const { p->SetupFloats(ctx); } - NetworkConfig BuildConfKey() const { return conv_problem.BuildConfKey(); } - - void SetupFloats(ExecutionContext& ctx) const { conv_problem.SetupFloats(ctx); }; + private: + const conv::ProblemDescription* p = nullptr; + friend struct ProblemDescription; + } conv_problem; +#endif }; // For mlo_construct_base @@ -387,19 +301,19 @@ struct UnifiedDescriptionConv2d if(!problem.Is2d()) MIOPEN_THROW(miopenStatusInternalError, "UnifiedDescriptionConv2d supports only 2D"); - const auto n_inputs_per_group = problem.GetInChannels() / problem.GetGroupCount(); - const auto n_outputs_per_group = problem.GetOutChannels() / problem.GetGroupCount(); + const auto n_inputs_per_group = problem.GetInChannels_() / problem.GetGroupCount(); + const auto n_outputs_per_group = problem.GetOutChannels_() / problem.GetGroupCount(); if(!problem.direction.IsBackwardWrW()) { - R = problem.GetWeightsHeight(); - S = problem.GetWeightsWidth(); + R = problem.GetWeightsHeight_(); + S = problem.GetWeightsWidth_(); U = problem.direction.IsForward() ? problem.GetKernelStrideH() : 1; V = problem.direction.IsForward() ? problem.GetKernelStrideW() : 1; - C = n_inputs_per_group; // Bwd: C and K is reversed in ProblemDescription. - K = n_outputs_per_group; // Ditto. - out_h = problem.GetOutHeight(); // Bwd: height/width is reversed in ProblemDescription. - out_w = problem.GetOutWidth(); // Ditto. - N = problem.GetBatchSize(); + C = n_inputs_per_group; // Bwd: C and K is reversed in ProblemDescription. + K = n_outputs_per_group; // Ditto. + out_h = problem.GetOutHeight_(); // Bwd: height/width is reversed in ProblemDescription. + out_w = problem.GetOutWidth_(); // Ditto. + N = problem.GetBatchSize_(); pad_h = problem.direction.IsForward() ? problem.GetPadH() : problem.GetBackwardPadH(); pad_w = problem.direction.IsForward() ? problem.GetPadW() : problem.GetBackwardPadW(); input_stride_h = problem.direction.IsForward() ? 1 : problem.GetKernelStrideH(); @@ -409,14 +323,14 @@ struct UnifiedDescriptionConv2d } else { // WrW - R = problem.GetInHeight(); - S = problem.GetInWidth(); + R = problem.GetInHeight_(); + S = problem.GetInWidth_(); U = problem.GetDilationH(); V = problem.GetDilationW(); - C = problem.GetBatchSize(); + C = problem.GetBatchSize_(); K = n_inputs_per_group; - out_h = problem.GetWeightsHeight(); - out_w = problem.GetWeightsWidth(); + out_h = problem.GetWeightsHeight_(); + out_w = problem.GetWeightsWidth_(); N = n_outputs_per_group; pad_h = problem.GetPadH(); pad_w = problem.GetPadW(); diff --git a/src/include/miopen/solver.hpp b/src/include/miopen/solver.hpp index 2115987769..811073fd4b 100644 --- a/src/include/miopen/solver.hpp +++ b/src/include/miopen/solver.hpp @@ -3063,7 +3063,7 @@ struct GemmFwdBase : ConvSolver bool IsDynamic() const override { return true; } float GetWti(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetWti(static_cast(ctx), problem.conv_problem); + return GetWti(static_cast(ctx), problem); } private: @@ -3086,7 +3086,7 @@ struct GemmFwd1x1_0_2 final : GemmFwdBase size_t GetWorkspaceSize(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetWorkspaceSize(static_cast(ctx), problem.conv_problem); + return GetWorkspaceSize(static_cast(ctx), problem); } bool MayNeedWorkspace() const override { return true; } @@ -3094,13 +3094,13 @@ struct GemmFwd1x1_0_2 final : GemmFwdBase bool IsApplicable(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return IsApplicable(static_cast(ctx), problem.conv_problem); + return IsApplicable(static_cast(ctx), problem); } ConvSolution GetSolution(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetSolution(static_cast(ctx), problem.conv_problem); + return GetSolution(static_cast(ctx), problem); } private: @@ -3121,7 +3121,7 @@ struct GemmFwd1x1_0_1_int8 final : GemmFwdBase size_t GetWorkspaceSize(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetWorkspaceSize(static_cast(ctx), problem.conv_problem); + return GetWorkspaceSize(static_cast(ctx), problem); } bool MayNeedWorkspace() const override { return true; } @@ -3129,13 +3129,13 @@ struct GemmFwd1x1_0_1_int8 final : GemmFwdBase bool IsApplicable(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return IsApplicable(static_cast(ctx), problem.conv_problem); + return IsApplicable(static_cast(ctx), problem); } ConvSolution GetSolution(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetSolution(static_cast(ctx), problem.conv_problem); + return GetSolution(static_cast(ctx), problem); } private: @@ -3156,7 +3156,7 @@ struct GemmFwd1x1_0_1 final : GemmFwdBase size_t GetWorkspaceSize(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetWorkspaceSize(static_cast(ctx), problem.conv_problem); + return GetWorkspaceSize(static_cast(ctx), problem); } bool MayNeedWorkspace() const override { return true; } @@ -3164,13 +3164,13 @@ struct GemmFwd1x1_0_1 final : GemmFwdBase bool IsApplicable(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return IsApplicable(static_cast(ctx), problem.conv_problem); + return IsApplicable(static_cast(ctx), problem); } ConvSolution GetSolution(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetSolution(static_cast(ctx), problem.conv_problem); + return GetSolution(static_cast(ctx), problem); } private: @@ -3191,7 +3191,7 @@ struct GemmFwdRest final : GemmFwdBase size_t GetWorkspaceSize(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetWorkspaceSize(static_cast(ctx), problem.conv_problem); + return GetWorkspaceSize(static_cast(ctx), problem); } bool MayNeedWorkspace() const override { return true; } @@ -3199,13 +3199,13 @@ struct GemmFwdRest final : GemmFwdBase bool IsApplicable(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return IsApplicable(static_cast(ctx), problem.conv_problem); + return IsApplicable(static_cast(ctx), problem); } ConvSolution GetSolution(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetSolution(static_cast(ctx), problem.conv_problem); + return GetSolution(static_cast(ctx), problem); } private: @@ -3223,7 +3223,7 @@ struct GemmBwdBase : ConvSolver bool IsDynamic() const override { return true; } float GetWti(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetWti(static_cast(ctx), problem.conv_problem); + return GetWti(static_cast(ctx), problem); } private: @@ -3246,7 +3246,7 @@ struct GemmBwd1x1_stride2 final : GemmBwdBase size_t GetWorkspaceSize(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetWorkspaceSize(static_cast(ctx), problem.conv_problem); + return GetWorkspaceSize(static_cast(ctx), problem); } bool MayNeedWorkspace() const override { return true; } @@ -3254,13 +3254,13 @@ struct GemmBwd1x1_stride2 final : GemmBwdBase bool IsApplicable(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return IsApplicable(static_cast(ctx), problem.conv_problem); + return IsApplicable(static_cast(ctx), problem); } ConvSolution GetSolution(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetSolution(static_cast(ctx), problem.conv_problem); + return GetSolution(static_cast(ctx), problem); } private: @@ -3281,7 +3281,7 @@ struct GemmBwd1x1_stride1 final : GemmBwdBase size_t GetWorkspaceSize(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetWorkspaceSize(static_cast(ctx), problem.conv_problem); + return GetWorkspaceSize(static_cast(ctx), problem); } bool MayNeedWorkspace() const override { return true; } @@ -3289,13 +3289,13 @@ struct GemmBwd1x1_stride1 final : GemmBwdBase bool IsApplicable(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return IsApplicable(static_cast(ctx), problem.conv_problem); + return IsApplicable(static_cast(ctx), problem); } ConvSolution GetSolution(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetSolution(static_cast(ctx), problem.conv_problem); + return GetSolution(static_cast(ctx), problem); } private: @@ -3318,7 +3318,7 @@ struct GemmBwdRest final : GemmBwdBase size_t GetWorkspaceSize(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetWorkspaceSize(static_cast(ctx), problem.conv_problem); + return GetWorkspaceSize(static_cast(ctx), problem); } bool MayNeedWorkspace() const override { return true; } @@ -3326,13 +3326,13 @@ struct GemmBwdRest final : GemmBwdBase bool IsApplicable(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return IsApplicable(static_cast(ctx), problem.conv_problem); + return IsApplicable(static_cast(ctx), problem); } ConvSolution GetSolution(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetSolution(static_cast(ctx), problem.conv_problem); + return GetSolution(static_cast(ctx), problem); } private: @@ -3350,7 +3350,7 @@ struct GemmWrwBase : ConvSolver bool IsDynamic() const override { return true; } float GetWti(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetWti(static_cast(ctx), problem.conv_problem); + return GetWti(static_cast(ctx), problem); } private: @@ -3371,13 +3371,13 @@ struct GemmWrw1x1_stride1 final : GemmWrwBase bool IsApplicable(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return IsApplicable(static_cast(ctx), problem.conv_problem); + return IsApplicable(static_cast(ctx), problem); } ConvSolution GetSolution(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetSolution(static_cast(ctx), problem.conv_problem); + return GetSolution(static_cast(ctx), problem); } private: @@ -3397,7 +3397,7 @@ struct GemmWrwUniversal final : GemmWrwBase size_t GetWorkspaceSize(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetWorkspaceSize(static_cast(ctx), problem.conv_problem); + return GetWorkspaceSize(static_cast(ctx), problem); } bool MayNeedWorkspace() const override { return true; } @@ -3405,13 +3405,13 @@ struct GemmWrwUniversal final : GemmWrwBase bool IsApplicable(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return IsApplicable(static_cast(ctx), problem.conv_problem); + return IsApplicable(static_cast(ctx), problem); } ConvSolution GetSolution(const ConvolutionContext& ctx, const ProblemDescription& problem) const override { - return GetSolution(static_cast(ctx), problem.conv_problem); + return GetSolution(static_cast(ctx), problem); } private: @@ -4886,8 +4886,6 @@ struct ConvHipImplicitGemm3DGroupFwdXdlops final bool CheckCKApplicability(const ProblemDescription&) const; }; -struct AnySolver; - // Use struct as a syntactic sugar to make the intent as clear as possible. struct ThisSolverIsDeprecatedStatic { diff --git a/src/include/miopen/solver/implicitgemm_util.hpp b/src/include/miopen/solver/implicitgemm_util.hpp index b7d031fab3..88262b4a32 100644 --- a/src/include/miopen/solver/implicitgemm_util.hpp +++ b/src/include/miopen/solver/implicitgemm_util.hpp @@ -95,61 +95,61 @@ static inline std::size_t KernelFilterDilationW(const ProblemDescription& proble static inline std::size_t KernelOutputChannelK(const ProblemDescription& problem) { if(problem.direction.IsBackwardWrW()) - return problem.GetInChannels(); + return problem.GetInChannels_(); else - return problem.GetOutChannels(); + return problem.GetOutChannels_(); } static inline std::size_t KernelInputChannelC(const ProblemDescription& problem) { if(problem.direction.IsBackwardWrW()) - return problem.GetBatchSize(); + return problem.GetBatchSize_(); else - return problem.GetInChannels() / problem.GetGroupCount(); + return problem.GetInChannels_() / problem.GetGroupCount(); } static inline std::size_t KernelBatchN(const ProblemDescription& problem) { if(problem.direction.IsBackwardWrW()) - return problem.GetOutChannels() / problem.GetGroupCount(); + return problem.GetOutChannels_() / problem.GetGroupCount(); else - return problem.GetBatchSize(); + return problem.GetBatchSize_(); } static inline std::size_t KernelOutputHeightHo(const ProblemDescription& problem) { if(problem.direction.IsForward()) - return problem.GetOutHeight(); + return problem.GetOutHeight_(); else if(problem.direction.IsBackwardWrW()) - return problem.GetWeightsHeight(); + return problem.GetWeightsHeight_(); else - return problem.GetInHeight(); + return problem.GetInHeight_(); } static inline std::size_t KernelOutputWidthWo(const ProblemDescription& problem) { if(problem.direction.IsForward()) - return problem.GetOutWidth(); + return problem.GetOutWidth_(); else if(problem.direction.IsBackwardWrW()) - return problem.GetWeightsWidth(); + return problem.GetWeightsWidth_(); else - return problem.GetInWidth(); + return problem.GetInWidth_(); } static inline std::size_t KernelFilterWidthX(const ProblemDescription& problem) { if(problem.direction.IsBackwardWrW()) - return problem.GetInWidth(); + return problem.GetInWidth_(); else - return problem.GetWeightsWidth(); + return problem.GetWeightsWidth_(); } static inline std::size_t KernelFilterHeightY(const ProblemDescription& problem) { if(problem.direction.IsBackwardWrW()) - return problem.GetInHeight(); + return problem.GetInHeight_(); else - return problem.GetWeightsHeight(); + return problem.GetWeightsHeight_(); } /// \todo move to separate header and use in other solvers. diff --git a/src/include/miopen/solver/problem_description_interpreter.hpp b/src/include/miopen/solver/problem_description_interpreter.hpp index 2cb3633b41..3e9e7fb3de 100644 --- a/src/include/miopen/solver/problem_description_interpreter.hpp +++ b/src/include/miopen/solver/problem_description_interpreter.hpp @@ -47,7 +47,7 @@ struct ProblemInterpreter return problem.GetGroupCount(); } - static auto GetBatchN(const ProblemDescription& problem) { return problem.GetBatchSize(); } + static int GetBatchN(const ProblemDescription& problem) { return problem.GetBatchSize_(); } static auto GetOutputLayout(const ProblemDescription& problem) { @@ -57,12 +57,12 @@ struct ProblemInterpreter return problem.GetInLayout(); } - static auto GetOutputChannelK(const ProblemDescription& problem) + static int GetOutputChannelK(const ProblemDescription& problem) { if(problem.direction.IsForward()) - return problem.GetOutChannels(); + return problem.GetOutChannels_(); else - return problem.GetInChannels(); + return problem.GetInChannels_(); } static auto GetInputLayout(const ProblemDescription& problem) @@ -73,60 +73,60 @@ struct ProblemInterpreter return problem.GetOutLayout(); } - static auto GetInputChannelC(const ProblemDescription& problem) + static int GetInputChannelC(const ProblemDescription& problem) { if(problem.direction.IsForward()) - return problem.GetInChannels(); + return problem.GetInChannels_(); else - return problem.GetOutChannels(); + return problem.GetOutChannels_(); } - static auto GetInputDepthDi(const ProblemDescription& problem) + static int GetInputDepthDi(const ProblemDescription& problem) { if(problem.direction.IsForward()) - return problem.GetInDepth(); + return problem.GetInDepth_(); else - return problem.GetOutDepth(); + return problem.GetOutDepth_(); } - static auto GetInputHeightHi(const ProblemDescription& problem) + static int GetInputHeightHi(const ProblemDescription& problem) { if(problem.direction.IsForward()) - return problem.GetInHeight(); + return problem.GetInHeight_(); else - return problem.GetOutHeight(); + return problem.GetOutHeight_(); } - static auto GetInputWidthWi(const ProblemDescription& problem) + static int GetInputWidthWi(const ProblemDescription& problem) { if(problem.direction.IsForward()) - return problem.GetInWidth(); + return problem.GetInWidth_(); else - return problem.GetOutWidth(); + return problem.GetOutWidth_(); } - static auto GetOutputDepthDo(const ProblemDescription& problem) + static int GetOutputDepthDo(const ProblemDescription& problem) { if(problem.direction.IsForward()) - return problem.GetOutDepth(); + return problem.GetOutDepth_(); else - return problem.GetInDepth(); + return problem.GetInDepth_(); } - static auto GetOutputHeightHo(const ProblemDescription& problem) + static int GetOutputHeightHo(const ProblemDescription& problem) { if(problem.direction.IsForward()) - return problem.GetOutHeight(); + return problem.GetOutHeight_(); else - return problem.GetInHeight(); + return problem.GetInHeight_(); } - static auto GetOutputWidthWo(const ProblemDescription& problem) + static int GetOutputWidthWo(const ProblemDescription& problem) { if(problem.direction.IsForward()) - return problem.GetOutWidth(); + return problem.GetOutWidth_(); else - return problem.GetInWidth(); + return problem.GetInWidth_(); } static auto GetOutputDataType(const ProblemDescription& problem) @@ -139,9 +139,9 @@ struct ProblemInterpreter return problem.direction.IsForward() ? problem.GetInDataType() : problem.GetOutDataType(); } - static auto GetFilterDepthZ(const ProblemDescription& problem) + static int GetFilterDepthZ(const ProblemDescription& problem) { - return problem.GetWeightsDepth(); + return problem.GetWeightsDepth_(); } static auto GetFilterLayout(const ProblemDescription& problem) @@ -149,14 +149,14 @@ struct ProblemInterpreter return problem.GetWeightsLayout(); } - static auto GetFilterHeightY(const ProblemDescription& problem) + static int GetFilterHeightY(const ProblemDescription& problem) { - return problem.GetWeightsHeight(); + return problem.GetWeightsHeight_(); } - static auto GetFilterWidthX(const ProblemDescription& problem) + static int GetFilterWidthX(const ProblemDescription& problem) { - return problem.GetWeightsWidth(); + return problem.GetWeightsWidth_(); } // adjust conv_stride_d to 1 if Do is 1 diff --git a/src/include/miopen/tensor.hpp b/src/include/miopen/tensor.hpp index 093c63119c..e27622bc4a 100644 --- a/src/include/miopen/tensor.hpp +++ b/src/include/miopen/tensor.hpp @@ -285,7 +285,7 @@ struct TensorDescriptor : miopenTensorDescriptor }; template -constexpr auto GetNCDHW(int spatial_dims, const std::vector& data) +constexpr auto GetNCDHW(unsigned spatial_dims, const std::vector& data) { if(spatial_dims == 3) return miopen::tien<5>(data, 1); diff --git a/src/mlo_dir_conv.cpp b/src/mlo_dir_conv.cpp index 7ebb26d615..2f44a7d327 100644 --- a/src/mlo_dir_conv.cpp +++ b/src/mlo_dir_conv.cpp @@ -201,12 +201,6 @@ static auto GetBwdWrW2DSolvers() static auto GetFFTSolvers() { return miopen::solver::SolverContainer{}; } -bool IsGemmAplicable(const miopen::ConvolutionContext& ctx, - const miopen::ProblemDescription& problem) -{ - return GetGemmSolvers().IsAnySolverApplicable(ctx, problem); -} - std::vector FindAllGemmSolutions(const miopen::ConvolutionContext& ctx, const miopen::ProblemDescription& problem, diff --git a/src/problem_description.cpp b/src/problem_description.cpp index 6b31f76795..3d05031bbc 100644 --- a/src/problem_description.cpp +++ b/src/problem_description.cpp @@ -1,32 +1,13 @@ #include -#include - -#include -#include -#include - namespace miopen { -void ProblemDescription::BuildConfKey(std::string& conf_key) const -{ - conv_problem.BuildConfKey(conf_key); -} - -bool ProblemDescription::IsLayoutDefault() const { return conv_problem.IsLayoutDefault(); } - -bool ProblemDescription::IsLayoutNHWC() const { return conv_problem.IsLayoutNHWC(); } - -bool ProblemDescription::IsLayoutNCHWc() const { return conv_problem.IsLayoutNCHWc(); } - -void ProblemDescription::Serialize(std::ostream& stream) const -{ - return conv_problem.Serialize(stream); -} - ProblemDescription::ProblemDescription(conv::ProblemDescription desc) - : conv_problem(std::move(desc)), direction(conv_problem.GetDirection()) + : conv::ProblemDescription(std::move(desc)), direction(GetDirection()) { +#if FIN_OLD_PROBLEM_DESCRIPTION_COMPAT + conv_problem.p = this; +#endif } } // namespace miopen diff --git a/src/solver/conv_MP_bidirectional_winograd.cpp b/src/solver/conv_MP_bidirectional_winograd.cpp index 36859be3cb..fc2d769520 100644 --- a/src/solver/conv_MP_bidirectional_winograd.cpp +++ b/src/solver/conv_MP_bidirectional_winograd.cpp @@ -89,17 +89,17 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_CONV_PRECISE_ROCBLAS_TIMING) solver::ConvMPBidirectWinograd:: \ GetSolverWinoXformHWSize(); -#define DEFINE_SHADER_ALIASES(problem) \ - const auto group_cnt = (problem).GetGroupCount(); \ - const auto N = (problem).GetBatchSize(); \ - const int K = (problem).GetOutChannels() / group_cnt; \ - const int C = (problem).GetInChannels() / group_cnt; \ - const auto R = (problem).GetWeightsHeight(); \ - const auto S = (problem).GetWeightsWidth(); \ - const auto H = (problem).GetInHeight(); \ - const auto W = (problem).GetInWidth(); \ - const auto out_H = (problem).GetOutHeight(); \ - const auto out_W = (problem).GetOutWidth(); +#define DEFINE_SHADER_ALIASES(problem) \ + const auto group_cnt = (problem).GetGroupCount(); \ + const int N = (problem).GetBatchSize_(); \ + const int K = (problem).GetOutChannels_() / group_cnt; \ + const int C = (problem).GetInChannels_() / group_cnt; \ + const int R = (problem).GetWeightsHeight_(); \ + const int S = (problem).GetWeightsWidth_(); \ + const int H = (problem).GetInHeight_(); \ + const int W = (problem).GetInWidth_(); \ + const int out_H = (problem).GetOutHeight_(); \ + const int out_W = (problem).GetOutWidth_(); #if MIOPEN_BACKEND_HIP #define GENERATE_MAIN_OPTIONS(options) \ @@ -294,8 +294,8 @@ static bool IsApplicableTransform(const ConvolutionContext& ctx, const ProblemDe // clang-format off bool ok = ( - (problem.GetWeightsWidth() == WinoFilterW - && problem.GetWeightsHeight() == WinoFilterH) + (problem.GetWeightsWidth_() == WinoFilterW + && problem.GetWeightsHeight_() == WinoFilterH) && (problem.GetKernelStrideW() == 1) && problem.GetKernelStrideH() == problem.GetKernelStrideW() && problem.GetDilationW() == 1 @@ -457,7 +457,7 @@ static InvokerFactory MakeWinogradInvokerFactory(const ConvolutionContext& ctx, // clang-format off GemmDescriptor wino_gemm_desc{isColMajor,transA,transB,m,n,k, lda,ldb,ldc,batch_count,strideA,strideB, - strideC,alpha,beta,transform_data_type, problem.conv_problem.GetConv().attribute.deterministic }; + strideC,alpha,beta,transform_data_type, problem.GetConv().attribute.deterministic }; // clang-format on #else (void)wino_xform_w; @@ -729,7 +729,7 @@ ConvolutionContext ConvMPBidirectWinograd_xdlops(ctx)}; - transformed_problem.conv_problem.SetupFloats(transformed_ctx); + transformed_problem.SetupFloats(transformed_ctx); return transformed_ctx; } diff --git a/src/solver/conv_asm_1x1u.cpp b/src/solver/conv_asm_1x1u.cpp index f415955c8c..b71d195667 100644 --- a/src/solver/conv_asm_1x1u.cpp +++ b/src/solver/conv_asm_1x1u.cpp @@ -67,12 +67,12 @@ static inline bool UseUpsample(const ProblemDescription& problem) /// out_height/out_width and vice versa. static inline int AsmImgHeight(const ProblemDescription& problem) { - return UseSubsample(problem) ? problem.GetOutHeight() : problem.GetInHeight(); + return UseSubsample(problem) ? problem.GetOutHeight_() : problem.GetInHeight_(); } static inline int AsmImgWidth(const ProblemDescription& problem) { - return UseSubsample(problem) ? problem.GetOutWidth() : problem.GetInWidth(); + return UseSubsample(problem) ? problem.GetOutWidth_() : problem.GetInWidth_(); } /// \todo move to separate header and use in other solvers. @@ -295,14 +295,14 @@ bool PerformanceConfigConvAsm1x1U::IsValidImpl(const ProblemDescription& problem const auto elements_in_dword = 4 / static_cast(GetTypeSize(problem.GetInDataType())); if(elements_in_dword == 0) // For clang-tidy (DIV/0) MIOPEN_THROW(miopenStatusInternalError); - const auto img_hw = problem.GetOutHeight() * problem.GetOutWidth(); + const auto img_hw = problem.GetOutHeight_() * problem.GetOutWidth_(); if(!IsValidValueImpl(sequence_length)) return false; if(sequence_length > 1) { if((k_mult % elements_in_dword) != 0) return false; - if(problem.direction.IsBackwardData() && !(problem.GetOutChannels() % k_mult == 0)) + if(problem.direction.IsBackwardData() && !(problem.GetOutChannels_() % k_mult == 0)) return false; } if(sequence_length > 2) @@ -320,7 +320,7 @@ bool PerformanceConfigConvAsm1x1U::IsValidImpl(const ProblemDescription& problem } if(sequence_length > 4) { - const int total_n_blocks = (problem.GetBatchSize() + GetNPerGpr() - 1) / GetNPerGpr(); + const int total_n_blocks = (problem.GetBatchSize_() + GetNPerGpr() - 1) / GetNPerGpr(); if(!(n_mult <= total_n_blocks)) return false; } @@ -344,16 +344,17 @@ bool PerformanceConfigConvAsm1x1U::IsValidImpl(const ProblemDescription& problem } if(sequence_length > 6) { - if(!(waves_c_in_group <= problem.GetInChannels())) + if(!(waves_c_in_group <= problem.GetInChannels_())) return false; - const int c_per_wave = (problem.GetInChannels() + waves_c_in_group - 1) / waves_c_in_group; - const int c_per_last_wave = problem.GetInChannels() - (c_per_wave * (waves_c_in_group - 1)); + const int c_per_wave = (problem.GetInChannels_() + waves_c_in_group - 1) / waves_c_in_group; + const int c_per_last_wave = + problem.GetInChannels_() - (c_per_wave * (waves_c_in_group - 1)); if(c_per_wave % c_mult != 0 || c_per_last_wave % c_mult != 0) return false; } if(sequence_length > 7) { - if(!(k_mult * waves_k_in_group <= problem.GetOutChannels())) + if(!(k_mult * waves_k_in_group <= problem.GetOutChannels_())) return false; if(!(waves_c_in_group * waves_k_in_group <= 16)) return false; @@ -405,12 +406,12 @@ static std::vector TransformFeatures(const ProblemDescription& problem, s int offset = (problem.direction.IsForward() ? 0 : 1) + 1; features[(offset)*n + offset] = 1.0; features[3 * n + 3] = - float(problem.direction.IsForward() ? problem.GetInChannels() : problem.GetOutChannels()); + float(problem.direction.IsForward() ? problem.GetInChannels_() : problem.GetOutChannels_()); features[4 * n + 4] = - float(problem.direction.IsForward() ? problem.GetOutChannels() : problem.GetInChannels()); - features[5 * n + 5] = float(problem.GetInHeight()); - features[6 * n + 6] = float(problem.GetInWidth()); - features[7 * n + 7] = float(problem.GetBatchSize()); + float(problem.direction.IsForward() ? problem.GetOutChannels_() : problem.GetInChannels_()); + features[5 * n + 5] = float(problem.GetInHeight_()); + features[6 * n + 6] = float(problem.GetInWidth_()); + features[7 * n + 7] = float(problem.GetBatchSize_()); return features; } @@ -550,61 +551,61 @@ bool ConvAsm1x1U::IsApplicable(const ConvolutionContext& ctx, return false; } - if(name == "gfx90a" && problem.conv_problem.IsGfx90aFp16altRequired()) + if(name == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; const auto elements_in_dword = 4 / GetTypeSize(problem.GetInDataType()); if(elements_in_dword == 0) // For clang-tidy (false positive DIV/0) MIOPEN_THROW(miopenStatusInternalError); // clang-format off - const int img_hw = problem.GetOutHeight() * problem.GetOutWidth(); + const int img_hw = problem.GetOutHeight_() * problem.GetOutWidth_(); bool ok = (problem.GetPadW() == 0 // -q pad_w && problem.GetPadH() == 0 // -p pad_h && problem.GetKernelStrideW() <= 2 // -u stride_w && problem.GetKernelStrideW() == problem.GetKernelStrideH() - && problem.GetWeightsWidth() == 1 // -x S wei_w - && problem.GetWeightsHeight() == 1 // -y R wei_h + && problem.GetWeightsWidth_() == 1 // -x S wei_w + && problem.GetWeightsHeight_() == 1 // -y R wei_h && problem.GetDilationW() == 1 && problem.GetDilationH() == 1 && problem.GetBias() == 0 - && problem.GetInChannels() % elements_in_dword == 0 - && problem.GetOutChannels() % elements_in_dword == 0 + && problem.GetInChannels_() % elements_in_dword == 0 + && problem.GetOutChannels_() % elements_in_dword == 0 && problem.GetInLayout() == "NCHW" && problem.GetGroupCount() == 1 && img_hw >= elements_in_dword - && (elements_in_dword == 1 || problem.GetOutChannels() >= 4)); + && (elements_in_dword == 1 || problem.GetOutChannels_() >= 4)); if(problem.direction.IsBackwardData() && elements_in_dword != 1) - ok = ok && (problem.GetOutChannels() % 4 == 0); + ok = ok && (problem.GetOutChannels_() % 4 == 0); if(!ok) { return false; // Early exit to speed up the check. } /// \todo Ilya: The checks below look adequate but needs to be double-checked. { - const int64_t input_line_size = 4 * static_cast(problem.GetInWidth()); - const int64_t input_feature_map_size = input_line_size * problem.GetInHeight(); - const int64_t input_stack_size = input_feature_map_size * problem.GetInChannels(); + const int64_t input_line_size = 4 * static_cast(problem.GetInWidth_()); + const int64_t input_feature_map_size = input_line_size * problem.GetInHeight_(); + const int64_t input_stack_size = input_feature_map_size * problem.GetInChannels_(); if (! (input_stack_size < (1U << 24))) return false; } { - const int64_t output_line_size = 4 * static_cast(problem.GetOutWidth()); - const int64_t output_feature_map_size = output_line_size * problem.GetOutHeight(); - const int64_t output_stack_size = output_feature_map_size * problem.GetOutChannels(); + const int64_t output_line_size = 4 * static_cast(problem.GetOutWidth_()); + const int64_t output_feature_map_size = output_line_size * problem.GetOutHeight_(); + const int64_t output_stack_size = output_feature_map_size * problem.GetOutChannels_(); if (! (output_stack_size < (1U << 24))) return false; } // Check limits: auto h_w = static_cast(AsmImgHeight(problem)) * AsmImgWidth(problem); - const auto r_s = static_cast(problem.GetWeightsHeight()) * problem.GetWeightsWidth(); - const auto c_h_w = static_cast(problem.GetInChannels()) * h_w; // C*H*W - const auto k_h_w = static_cast(problem.GetOutChannels()) * h_w; // K*H*W - const auto n_c_h_w = static_cast(problem.GetBatchSize()) * c_h_w; // N*C*H*W - const auto n_k_h_w = static_cast(problem.GetBatchSize()) * k_h_w; // N*K*H*W - const auto c_k_r_s = static_cast(problem.GetInChannels()) * problem.GetOutChannels() * r_s; // C*K*R*S - ok = problem.GetBatchSize() < std::pow(2, 16) // -n N batch_size - && problem.GetInChannels() < std::pow(2, 16) // -c C input_channels - && problem.GetOutChannels() < std::pow(2, 16) // -k K output_channels + const auto r_s = static_cast(problem.GetWeightsHeight_()) * problem.GetWeightsWidth_(); + const auto c_h_w = static_cast(problem.GetInChannels_()) * h_w; // C*H*W + const auto k_h_w = static_cast(problem.GetOutChannels_()) * h_w; // K*H*W + const auto n_c_h_w = static_cast(problem.GetBatchSize_()) * c_h_w; // N*C*H*W + const auto n_k_h_w = static_cast(problem.GetBatchSize_()) * k_h_w; // N*K*H*W + const auto c_k_r_s = static_cast(problem.GetInChannels_()) * problem.GetOutChannels_() * r_s; // C*K*R*S + ok = problem.GetBatchSize_() < std::pow(2, 16) // -n N batch_size + && problem.GetInChannels_() < std::pow(2, 16) // -c C input_channels + && problem.GetOutChannels_() < std::pow(2, 16) // -k K output_channels && c_h_w < std::pow(2, 24) && k_h_w < std::pow(2, 24) && n_c_h_w < std::pow(2, 29) @@ -620,9 +621,9 @@ size_t ConvAsm1x1U::GetWorkspaceSize(const ConvolutionContext&, { int in_batch_stride = AsmImgWidth(problem) * AsmImgHeight(problem) * - (UseSubsample(problem) ? problem.GetInChannels() : problem.GetOutChannels()); + (UseSubsample(problem) ? problem.GetInChannels_() : problem.GetOutChannels_()); int data_len = GetTypeSize(problem.GetOutDataType()); - return static_cast(in_batch_stride) * problem.GetBatchSize() * data_len; + return static_cast(in_batch_stride) * problem.GetBatchSize_() * data_len; } return 0; } @@ -649,7 +650,7 @@ ConvSolution ConvAsm1x1U::GetSolution(const ConvolutionContext& ctx, // subsampled input, in_height equals to image size after downsampling int in_batch_stride = AsmImgWidth(problem) * AsmImgHeight(problem) * - (UseSubsample(problem) ? problem.GetInChannels() : problem.GetOutChannels()); + (UseSubsample(problem) ? problem.GetInChannels_() : problem.GetOutChannels_()); int write_unit = (AsmImgWidth(problem) % 4 == 0) ? 4 : (AsmImgWidth(problem) % 3 == 0) ? 3 : (AsmImgWidth(problem) % 2 == 0) ? 2 @@ -666,14 +667,14 @@ ConvSolution ConvAsm1x1U::GetSolution(const ConvolutionContext& ctx, std::string(" -DMLO_FILTER0_STRIDE0=") + std::to_string(problem.GetKernelStrideW()) + std::string(" -DMLO_FILTER0_STRIDE1=") + std::to_string(problem.GetKernelStrideH()) + std::string(" -DMLO_WRITE_UNIT=") + std::to_string(write_unit) + - std::string(" -DMLO_OUT_CHANNEL_STRIDE=") + std::to_string(problem.GetOutChannelStride()) + - std::string(" -DMLO_OUT_STRIDE=") + std::to_string(problem.GetOutStride()) + + std::string(" -DMLO_OUT_CHANNEL_STRIDE=") + std::to_string(problem.GetOutChannelStride_()) + + std::string(" -DMLO_OUT_STRIDE=") + std::to_string(problem.GetOutStrideH_()) + std::string(" -DMLO_IN_BATCH_STRIDE=") + std::to_string(in_batch_stride) + std::string(" -DMLO_IN0_BATCH_STRIDE=") + - std::to_string(problem.direction.IsForward() ? problem.GetInBatchStride() - : problem.GetOutBatchStride()) + - std::string(" -DMLO_IN0_CHANNEL_STRIDE=") + std::to_string(problem.GetInChannelStride()) + - std::string(" -DMLO_IN0_STRIDE=") + std::to_string(problem.GetInStride()) + + std::to_string(problem.direction.IsForward() ? problem.GetInBatchStride_() + : problem.GetOutBatchStride_()) + + std::string(" -DMLO_IN0_CHANNEL_STRIDE=") + std::to_string(problem.GetInChannelStride_()) + + std::string(" -DMLO_IN0_STRIDE=") + std::to_string(problem.GetInStrideH_()) + ctx.general_compile_options; // clang-format on @@ -682,7 +683,7 @@ ConvSolution ConvAsm1x1U::GetSolution(const ConvolutionContext& ctx, ss_us_kernel.l_wk.push_back(1); // output is number of subsampled input maps size_t gbl_wk0 = (in_batch_stride / write_unit); - size_t gbl_wk1 = problem.GetBatchSize(); + size_t gbl_wk1 = problem.GetBatchSize_(); size_t gbl_wk2 = 1; ss_us_kernel.g_wk.push_back(gbl_wk0); @@ -706,11 +707,11 @@ ConvSolution ConvAsm1x1U::GetSolution(const ConvolutionContext& ctx, GenerateClangDefsym(options, "img_w", AsmImgWidth(problem)); // W // Note that problem.n_outputs and problem.n_inputs are swapped for backward convolutions. - GenerateClangDefsym(options, "batch_size", problem.GetBatchSize()); // N - GenerateClangDefsym(options, "input_channels", problem.GetInChannels()); // C - GenerateClangDefsym(options, "output_channels", problem.GetOutChannels()); // K - GenerateClangDefsym(options, "wei_h", problem.GetWeightsHeight()); // R - GenerateClangDefsym(options, "wei_w", problem.GetWeightsWidth()); // S + GenerateClangDefsym(options, "batch_size", problem.GetBatchSize_()); // N + GenerateClangDefsym(options, "input_channels", problem.GetInChannels_()); // C + GenerateClangDefsym(options, "output_channels", problem.GetOutChannels_()); // K + GenerateClangDefsym(options, "wei_h", problem.GetWeightsHeight_()); // R + GenerateClangDefsym(options, "wei_w", problem.GetWeightsWidth_()); // S GenerateClangDefsym(options, "pad_h", problem.GetPadH()); GenerateClangDefsym(options, "pad_w", problem.GetPadW()); GenerateClangDefsym(options, "weights_layout", problem.direction.IsForward() ? 0 : 1); @@ -773,24 +774,24 @@ ConvSolution ConvAsm1x1U::GetSolution(const ConvolutionContext& ctx, // cppcheck-suppress unreadVariable buff_info ibuf(MemLayout::NCHW, - problem.GetBatchSize(), - problem.GetInChannels(), + problem.GetBatchSize_(), + problem.GetInChannels_(), AsmImgHeight(problem), AsmImgWidth(problem), 1, data_len); // cppcheck-suppress unreadVariable buff_info obuf(MemLayout::NCHW, - problem.GetBatchSize(), - problem.GetOutChannels(), + problem.GetBatchSize_(), + problem.GetOutChannels_(), AsmImgHeight(problem), AsmImgWidth(problem), 1, data_len); // cppcheck-suppress unreadVariable buff_info fbuf(problem.direction.IsForward() ? MemLayout::NCHW : MemLayout::CNHW, - problem.GetOutChannels(), - problem.GetInChannels(), + problem.GetOutChannels_(), + problem.GetInChannels_(), 1, 1, 1, @@ -867,10 +868,10 @@ ConvSolution ConvAsm1x1U::GetSolution(const ConvolutionContext& ctx, main_kernel.l_wk[0] * divide_round_plus_inf(AsmImgHeight(problem) * AsmImgWidth(problem), hw_per_wave)); - main_kernel.g_wk.push_back(divide_round_plus_inf(problem.GetOutChannels(), + main_kernel.g_wk.push_back(divide_round_plus_inf(problem.GetOutChannels_(), pcfg->GetKMult() * pcfg->GetWavesKInGroup())); const int n_images_per_wave = pcfg->GetNMult() * pcfg->GetNPerGpr(); - main_kernel.g_wk.push_back(divide_round_plus_inf(problem.GetBatchSize(), n_images_per_wave)); + main_kernel.g_wk.push_back(divide_round_plus_inf(problem.GetBatchSize_(), n_images_per_wave)); main_kernel.kernel_file = "conv1x1u.s"; main_kernel.kernel_name = "miopenGcnAsmConv1x1U"; diff --git a/src/solver/conv_asm_1x1u_bias_activ_fused.cpp b/src/solver/conv_asm_1x1u_bias_activ_fused.cpp index 1ba10fad0e..c4ae30f859 100644 --- a/src/solver/conv_asm_1x1u_bias_activ_fused.cpp +++ b/src/solver/conv_asm_1x1u_bias_activ_fused.cpp @@ -142,7 +142,7 @@ ConvBiasActivAsm1x1U::GetSolution(const FusionContext& context, } kernel_info.comp_options += cba_options.str(); - const auto out_data_type = conv_problem.conv_problem.GetOutDataType(); + const auto out_data_type = conv_problem.GetOutDataType(); sol.weight = 50.0f; sol.invoker_factory = [=](const std::vector& kernels) { @@ -247,13 +247,13 @@ bool ConvBiasActivAsm1x1U::IsApplicable(const FusionContext& context, return false; if(conv_problem.GetPadH() != 0) return false; - if(conv_problem.conv_problem.GetKernelStrideH() != conv_problem.conv_problem.GetKernelStrideW()) + if(conv_problem.GetKernelStrideH() != conv_problem.GetKernelStrideW()) return false; - if(conv_problem.conv_problem.GetKernelStrideH() != 1) + if(conv_problem.GetKernelStrideH() != 1) return false; - if(conv_problem.conv_problem.GetDilationH() != conv_problem.conv_problem.GetDilationW()) + if(conv_problem.GetDilationH() != conv_problem.GetDilationW()) return false; - if(conv_problem.conv_problem.GetDilationH() != 1) + if(conv_problem.GetDilationH() != 1) return false; // Check if the conovlution part is applicable diff --git a/src/solver/conv_asm_1x1u_stride2.cpp b/src/solver/conv_asm_1x1u_stride2.cpp index e51966b238..9b3dd0462d 100644 --- a/src/solver/conv_asm_1x1u_stride2.cpp +++ b/src/solver/conv_asm_1x1u_stride2.cpp @@ -131,8 +131,8 @@ struct config_helper dilation_h = problem.GetKernelStrideH(); } - in_strided_w = divide_round_plus_inf(problem.GetInWidth(), stride_w); - in_strided_h = divide_round_plus_inf(problem.GetInHeight(), stride_h); + in_strided_w = divide_round_plus_inf(problem.GetInWidth_(), stride_w); + in_strided_h = divide_round_plus_inf(problem.GetInHeight_(), stride_h); w_per_wave = static_cast(divide_round_plus_inf(config.dwords_per_ld, stride_w) * config.w_mult * (config.chunk_size / config.h_per_chunk)); @@ -334,11 +334,11 @@ bool PerformanceConfigConvAsm1x1UV2::IsValid(const ProblemDescription& problem) return false; if(!(waves_c_in_group * waves_k_in_group <= 16)) return false; - if(!(waves_c_in_group <= problem.GetInChannels())) + if(!(waves_c_in_group <= problem.GetInChannels_())) return false; if(!(h_per_chunk <= chunk_size)) return false; - if(!(k_mult * waves_k_in_group <= problem.GetOutChannels())) + if(!(k_mult * waves_k_in_group <= problem.GetOutChannels_())) return false; // cppcheck-suppress unreadVariable @@ -379,37 +379,37 @@ bool PerformanceConfigConvAsm1x1UV2::IsValid(const ProblemDescription& problem) const auto sgprs = 25 + 2 * k_mult * c_mult; if(!(sgprs < 102)) return false; - const auto total_n_blocks = (problem.GetBatchSize() + GetNPerGpr() - 1) / GetNPerGpr(); + const auto total_n_blocks = (problem.GetBatchSize_() + GetNPerGpr() - 1) / GetNPerGpr(); if(!(n_mult <= total_n_blocks)) return false; - const auto c_per_wave = (problem.GetInChannels() + waves_c_in_group - 1) / waves_c_in_group; - const auto c_per_last_wave = problem.GetInChannels() - (c_per_wave * (waves_c_in_group - 1)); + const auto c_per_wave = (problem.GetInChannels_() + waves_c_in_group - 1) / waves_c_in_group; + const auto c_per_last_wave = problem.GetInChannels_() - (c_per_wave * (waves_c_in_group - 1)); - if(problem.direction.IsBackwardData() && !(problem.GetOutChannels() % k_mult == 0)) + if(problem.direction.IsBackwardData() && !(problem.GetOutChannels_() % k_mult == 0)) return false; { // cppcheck-suppress unreadVariable buff_info ibuf(MemLayout::NCHW, - problem.GetBatchSize(), - problem.GetInChannels(), - problem.GetInHeight(), - problem.GetInWidth(), + problem.GetBatchSize_(), + problem.GetInChannels_(), + problem.GetInHeight_(), + problem.GetInWidth_(), 1, GetTypeSize(problem.GetInDataType())); // cppcheck-suppress unreadVariable buff_info obuf(MemLayout::NCHW, - problem.GetBatchSize(), - problem.GetOutChannels(), - problem.GetOutHeight(), - problem.GetOutWidth(), + problem.GetBatchSize_(), + problem.GetOutChannels_(), + problem.GetOutHeight_(), + problem.GetOutWidth_(), 1, GetTypeSize(problem.GetOutDataType())); int n_miss = n_mult * GetNPerGpr() - 1; - if((static_cast(problem.GetInChannels()) + n_miss) * ibuf.byte_stride.nk >= + if((static_cast(problem.GetInChannels_()) + n_miss) * ibuf.byte_stride.nk >= (1LL << 31) || - (static_cast(problem.GetOutChannels()) + n_miss) * obuf.byte_stride.nk >= + (static_cast(problem.GetOutChannels_()) + n_miss) * obuf.byte_stride.nk >= (1LL << 31)) return false; } @@ -418,8 +418,8 @@ bool PerformanceConfigConvAsm1x1UV2::IsValid(const ProblemDescription& problem) void PerformanceConfigConvAsm1x1UV2::HeuristicInit(const ProblemDescription& problem) { - int c_check = problem.direction.IsForward() ? problem.GetInChannels() : 0; - int k_check = problem.direction.IsForward() ? 0 : problem.GetInChannels(); + int c_check = problem.direction.IsForward() ? problem.GetInChannels_() : 0; + int k_check = problem.direction.IsForward() ? 0 : problem.GetInChannels_(); chunk_size = 16; dwords_per_ld = 1; c_mult = (c_check % 2 == 0) ? 2 : ((c_check % 3 == 0) ? 3 : 1); @@ -512,11 +512,11 @@ bool ConvAsm1x1UV2::IsApplicable(const ConvolutionContext& ctx, const auto elements_in_dword = 4 / GetTypeSize(problem.GetInDataType()); // clang-format off - const auto img_hw = problem.GetOutHeight() * problem.GetOutWidth(); + const auto img_hw = problem.GetOutHeight_() * problem.GetOutWidth_(); bool ok = (problem.GetPadW() == 0 && problem.GetPadH() == 0 - && problem.GetWeightsWidth() == 1 - && problem.GetWeightsHeight() == 1 + && problem.GetWeightsWidth_() == 1 + && problem.GetWeightsHeight_() == 1 && problem.GetKernelStrideW() <= 2 && problem.GetKernelStrideW() == problem.GetKernelStrideH() && problem.GetDilationW() == 1 @@ -537,16 +537,16 @@ bool ConvAsm1x1UV2::IsApplicable(const ConvolutionContext& ctx, } // Check limits: - auto h_w = static_cast(problem.GetInHeight()) * problem.GetInWidth(); - const auto r_s = static_cast(problem.GetWeightsHeight()) * problem.GetWeightsWidth(); - const auto c_h_w = static_cast(problem.GetInChannels()) * h_w; // C*H*W - const auto k_h_w = static_cast(problem.GetOutChannels()) * h_w; // K*H*W - const auto n_c_h_w = static_cast(problem.GetBatchSize()) * c_h_w; // N*C*H*W - const auto n_k_h_w = static_cast(problem.GetBatchSize()) * k_h_w; // N*K*H*W - const auto c_k_r_s = static_cast(problem.GetInChannels()) * problem.GetOutChannels() * r_s; // C*K*R*S - ok = problem.GetBatchSize() < std::pow(2, 16) // -n N batch_size - && problem.GetInChannels() < std::pow(2, 16) // -c C input_channels - && problem.GetOutChannels() < std::pow(2, 16) // -k K output_channels + auto h_w = static_cast(problem.GetInHeight_()) * problem.GetInWidth_(); + const auto r_s = static_cast(problem.GetWeightsHeight_()) * problem.GetWeightsWidth_(); + const auto c_h_w = static_cast(problem.GetInChannels_()) * h_w; // C*H*W + const auto k_h_w = static_cast(problem.GetOutChannels_()) * h_w; // K*H*W + const auto n_c_h_w = static_cast(problem.GetBatchSize_()) * c_h_w; // N*C*H*W + const auto n_k_h_w = static_cast(problem.GetBatchSize_()) * k_h_w; // N*K*H*W + const auto c_k_r_s = static_cast(problem.GetInChannels_()) * problem.GetOutChannels_() * r_s; // C*K*R*S + ok = problem.GetBatchSize_() < std::pow(2, 16) // -n N batch_size + && problem.GetInChannels_() < std::pow(2, 16) // -c C input_channels + && problem.GetOutChannels_() < std::pow(2, 16) // -k K output_channels && c_h_w < std::pow(2, 24) && k_h_w < std::pow(2, 24) && n_c_h_w < std::pow(2, 29) @@ -560,18 +560,18 @@ bool ConvAsm1x1UV2::IsApplicable(const ConvolutionContext& ctx, const auto& config = problem; // alias // cppcheck-suppress unreadVariable buff_info ibuf(MemLayout::NCHW, - config.GetBatchSize(), - config.GetInChannels(), - config.GetInHeight(), - config.GetInWidth(), + config.GetBatchSize_(), + config.GetInChannels_(), + config.GetInHeight_(), + config.GetInWidth_(), 1, GetTypeSize(config.GetInDataType())); // cppcheck-suppress unreadVariable buff_info obuf(MemLayout::NCHW, - config.GetBatchSize(), - config.GetOutChannels(), - config.GetOutHeight(), - config.GetOutWidth(), + config.GetBatchSize_(), + config.GetOutChannels_(), + config.GetOutHeight_(), + config.GetOutWidth_(), 1, GetTypeSize(config.GetOutDataType())); @@ -579,9 +579,9 @@ bool ConvAsm1x1UV2::IsApplicable(const ConvolutionContext& ctx, const int eurictic_init_max_chunk_size = 16; const int n_miss = eurictic_init_min_n_mult * (64 / eurictic_init_max_chunk_size) - 1; - if((static_cast(config.GetInChannels()) + n_miss) * ibuf.byte_stride.nk >= + if((static_cast(config.GetInChannels_()) + n_miss) * ibuf.byte_stride.nk >= (1LL << 31) || - (static_cast(config.GetOutChannels()) + n_miss) * obuf.byte_stride.nk >= + (static_cast(config.GetOutChannels_()) + n_miss) * obuf.byte_stride.nk >= (1LL << 31)) ok = false; } @@ -633,18 +633,18 @@ ConvSolution ConvAsm1x1UV2::GetSolution(const ConvolutionContext& ctx, GenerateClangDefsym(options, "idilation_h", uv_lj.dilation_h); GenerateClangDefsym(options, "idilation_w", uv_lj.dilation_w); - GenerateClangDefsym(options, "img_h", problem.GetInHeight()); // H - GenerateClangDefsym(options, "img_w", problem.GetInWidth()); // W + GenerateClangDefsym(options, "img_h", problem.GetInHeight_()); // H + GenerateClangDefsym(options, "img_w", problem.GetInWidth_()); // W - GenerateClangDefsym(options, "out_h", problem.GetOutHeight()); // H - GenerateClangDefsym(options, "out_w", problem.GetOutWidth()); // W + GenerateClangDefsym(options, "out_h", problem.GetOutHeight_()); // H + GenerateClangDefsym(options, "out_w", problem.GetOutWidth_()); // W // Note that problem.n_outputs and problem.n_inputs are swapped for backward convolutions. - GenerateClangDefsym(options, "batch_size", problem.GetBatchSize()); // N - GenerateClangDefsym(options, "input_channels", problem.GetInChannels()); // C - GenerateClangDefsym(options, "output_channels", problem.GetOutChannels()); // K - GenerateClangDefsym(options, "wei_h", problem.GetWeightsHeight()); // R - GenerateClangDefsym(options, "wei_w", problem.GetWeightsWidth()); // S + GenerateClangDefsym(options, "batch_size", problem.GetBatchSize_()); // N + GenerateClangDefsym(options, "input_channels", problem.GetInChannels_()); // C + GenerateClangDefsym(options, "output_channels", problem.GetOutChannels_()); // K + GenerateClangDefsym(options, "wei_h", problem.GetWeightsHeight_()); // R + GenerateClangDefsym(options, "wei_w", problem.GetWeightsWidth_()); // S GenerateClangDefsym(options, "pad_h", problem.GetPadH()); GenerateClangDefsym(options, "pad_w", problem.GetPadW()); GenerateClangDefsym(options, "weights_layout", problem.direction.IsForward() ? 0 : 1); @@ -658,24 +658,24 @@ ConvSolution ConvAsm1x1UV2::GetSolution(const ConvolutionContext& ctx, // cppcheck-suppress unreadVariable buff_info ibuf(MemLayout::NCHW, - problem.GetBatchSize(), - problem.GetInChannels(), - problem.GetInHeight(), - problem.GetInWidth(), + problem.GetBatchSize_(), + problem.GetInChannels_(), + problem.GetInHeight_(), + problem.GetInWidth_(), 1, data_len); // cppcheck-suppress unreadVariable buff_info obuf(MemLayout::NCHW, - problem.GetBatchSize(), - problem.GetOutChannels(), - problem.GetOutHeight(), - problem.GetOutWidth(), + problem.GetBatchSize_(), + problem.GetOutChannels_(), + problem.GetOutHeight_(), + problem.GetOutWidth_(), 1, data_len); // cppcheck-suppress unreadVariable buff_info fbuf(problem.direction.IsForward() ? MemLayout::NCHW : MemLayout::CNHW, - problem.GetOutChannels(), - problem.GetInChannels(), + problem.GetOutChannels_(), + problem.GetInChannels_(), 1, 1, 1, @@ -730,9 +730,9 @@ ConvSolution ConvAsm1x1UV2::GetSolution(const ConvolutionContext& ctx, divide_round_plus_inf(uv_lj.in_strided_w, uv_lj.w_per_wave) * divide_round_plus_inf(uv_lj.in_strided_h, uv_lj.h_per_wave)); - kinfo.g_wk.push_back(divide_round_plus_inf(problem.GetOutChannels(), k_per_wave)); + kinfo.g_wk.push_back(divide_round_plus_inf(problem.GetOutChannels_(), k_per_wave)); - kinfo.g_wk.push_back(divide_round_plus_inf(problem.GetBatchSize(), n_per_wave)); + kinfo.g_wk.push_back(divide_round_plus_inf(problem.GetBatchSize_(), n_per_wave)); kinfo.kernel_file = "conv1x1u_stride2.s"; kinfo.kernel_name = "miopenGcnAsmConv1x1U_stride2"; diff --git a/src/solver/conv_asm_3x3u.cpp b/src/solver/conv_asm_3x3u.cpp index 5385d8d198..1ebb39c84f 100644 --- a/src/solver/conv_asm_3x3u.cpp +++ b/src/solver/conv_asm_3x3u.cpp @@ -82,15 +82,15 @@ bool PerformanceConfigConvAsm3x3U::IsValid(const ProblemDescription& problem) co if(!IsValidValue()) return false; // to-do add support of uneven_outputs into grouped conv - bool uneven_outputs = (problem.GetOutChannels() % filters_per_wave) != 0; - auto num_wavefronts = problem.GetOutChannels() / filters_per_wave; + bool uneven_outputs = (problem.GetOutChannels_() % filters_per_wave) != 0; + auto num_wavefronts = problem.GetOutChannels_() / filters_per_wave; if(problem.GetGroupCount() > 1 && (uneven_outputs || (num_wavefronts % problem.GetGroupCount() != 0))) return false; // Count the number of VGPRs required. - const auto img_width = problem.GetInWidth(); - const auto img_height = problem.GetInHeight(); + const auto img_width = problem.GetInWidth_(); + const auto img_height = problem.GetInHeight_(); int n = 0; const bool enable_zero_line_padding_on_read = (img_height != output_lines_per_wave); @@ -115,7 +115,7 @@ bool PerformanceConfigConvAsm3x3U::IsValid(const ProblemDescription& problem) co const int input_lines_per_wave = (img_height == output_lines_per_wave) ? output_lines_per_wave : (output_lines_per_wave + 2); - const int k_group_size = problem.GetOutChannels() / problem.GetGroupCount(); + const int k_group_size = problem.GetOutChannels_() / problem.GetGroupCount(); const bool k_group_size_is_power_of_two = ((k_group_size & (k_group_size - 1)) == 0); n += (k_group_size_is_power_of_two || gprs_per_input_line * input_lines_per_wave >= 4) ? (gprs_per_input_line * input_lines_per_wave) @@ -141,7 +141,7 @@ void PerformanceConfigConvAsm3x3U::HeuristicInit(const ProblemDescription& probl filters_per_wave = 2; output_lines_per_wave = 2; - if(problem.GetOutChannels() % (filters_per_wave * problem.GetGroupCount()) != 0) + if(problem.GetOutChannels_() % (filters_per_wave * problem.GetGroupCount()) != 0) { filters_per_wave = 1; } @@ -201,15 +201,15 @@ bool ConvAsm3x3U::IsApplicable(const ConvolutionContext& ctx, constexpr auto ELEM_SZ = static_cast(sizeof(float)); constexpr int64_t SHADER_FEATURE_INDEX_MAX = static_cast(-1); const auto IN_FEATURE_COUNT = - static_cast(problem.GetBatchSize()) * problem.GetInChannels(); + static_cast(problem.GetBatchSize_()) * problem.GetInChannels_(); const auto OUT_FEATURE_COUNT = - static_cast(problem.GetBatchSize()) * problem.GetOutChannels(); - const auto IN_IMG_SZ = ELEM_SZ * problem.GetInHeight() * problem.GetInWidth(); - const auto OUT_IMG_SZ = ELEM_SZ * problem.GetOutHeight() * problem.GetOutWidth(); + static_cast(problem.GetBatchSize_()) * problem.GetOutChannels_(); + const auto IN_IMG_SZ = ELEM_SZ * problem.GetInHeight_() * problem.GetInWidth_(); + const auto OUT_IMG_SZ = ELEM_SZ * problem.GetOutHeight_() * problem.GetOutWidth_(); const auto IN_BUF_SZ = IN_IMG_SZ * IN_FEATURE_COUNT; const auto OUT_BUF_SZ = OUT_IMG_SZ * OUT_FEATURE_COUNT; - const auto WEI_BUF_SZ = ELEM_SZ * problem.GetInChannels() * problem.GetOutChannels() * - problem.GetWeightsHeight() * problem.GetWeightsWidth(); + const auto WEI_BUF_SZ = ELEM_SZ * problem.GetInChannels_() * problem.GetOutChannels_() * + problem.GetWeightsHeight_() * problem.GetWeightsWidth_(); // clang-format off return problem.GetPadW() == 1 && problem.GetPadH() == 1 @@ -217,12 +217,12 @@ bool ConvAsm3x3U::IsApplicable(const ConvolutionContext& ctx, && problem.GetKernelStrideH() == 1 && problem.GetDilationW() == 1 && problem.GetDilationH() == 1 - && problem.GetWeightsWidth() == 3 - && problem.GetWeightsHeight() == 3 - && problem.GetInChannels() > 0 - && (problem.GetInChannels() / problem.GetGroupCount()) % 4 == 0 /// \todo: remove restriction that (n_inputs/group_counts) must be multiple of 4 - && problem.GetInWidth() > 3 - && problem.GetInWidth() <= 1000 + && problem.GetWeightsWidth_() == 3 + && problem.GetWeightsHeight_() == 3 + && problem.GetInChannels_() > 0 + && (problem.GetInChannels_() / problem.GetGroupCount()) % 4 == 0 /// \todo: remove restriction that (n_inputs/group_counts) must be multiple of 4 + && problem.GetInWidth_() > 3 + && problem.GetInWidth_() <= 1000 && IN_IMG_SZ <= GIB && OUT_IMG_SZ <= 4 * GIB && IN_FEATURE_COUNT - 1 <= SHADER_FEATURE_INDEX_MAX @@ -268,18 +268,18 @@ ConvSolution ConvAsm3x3U::GetSolution(const ConvolutionContext& ctx, } } - const int k_group_size = problem.GetOutChannels() / problem.GetGroupCount(); + const int k_group_size = problem.GetOutChannels_() / problem.GetGroupCount(); const bool k_group_size_is_power_of_two = ((k_group_size & (k_group_size - 1)) == 0); - const auto w64_chunks = (problem.GetInWidth() + 63) / 64; - const auto active_lanes = (problem.GetInWidth() + w64_chunks - 1) / w64_chunks; + const auto w64_chunks = (problem.GetInWidth_() + 63) / 64; + const auto active_lanes = (problem.GetInWidth_() + w64_chunks - 1) / w64_chunks; KernelBuildParameters options{ - {"batch_size", problem.GetBatchSize()}, - {"img_width", problem.GetInWidth()}, - {"img_height", problem.GetInHeight()}, - {"input_channels", problem.GetInChannels()}, - {"output_channels", problem.GetOutChannels()}, + {"batch_size", problem.GetBatchSize_()}, + {"img_width", problem.GetInWidth_()}, + {"img_height", problem.GetInHeight_()}, + {"input_channels", problem.GetInChannels_()}, + {"output_channels", problem.GetOutChannels_()}, {"weights_layout", problem.direction.IsForward() ? 0 : 1}, {"reverse_weights", problem.direction.IsForward() ? 0 : 1}, {"ROCM_METADATA_VERSION", ctx.rmv.UseV3() ? 5 : 4}, @@ -301,10 +301,10 @@ ConvSolution ConvAsm3x3U::GetSolution(const ConvolutionContext& ctx, construction_params.g_wk.push_back(static_cast( active_lanes * - ((problem.GetOutChannels() + pcfg->filters_per_wave - 1) / pcfg->filters_per_wave))); - construction_params.g_wk.push_back((problem.GetInHeight() + pcfg->output_lines_per_wave - 1) / + ((problem.GetOutChannels_() + pcfg->filters_per_wave - 1) / pcfg->filters_per_wave))); + construction_params.g_wk.push_back((problem.GetInHeight_() + pcfg->output_lines_per_wave - 1) / pcfg->output_lines_per_wave); - construction_params.g_wk.push_back(problem.GetBatchSize()); + construction_params.g_wk.push_back(problem.GetBatchSize_()); construction_params.kernel_file = "conv3x3.s"; construction_params.kernel_name = "miopenGcnAsmConv3x3U"; diff --git a/src/solver/conv_asm_5x10u2v2b1.cpp b/src/solver/conv_asm_5x10u2v2b1.cpp index 352d1ab842..fe0f0f42a3 100644 --- a/src/solver/conv_asm_5x10u2v2b1.cpp +++ b/src/solver/conv_asm_5x10u2v2b1.cpp @@ -83,20 +83,20 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx, const int max_out_height = 131077 - 1; // clang-format off - return problem.GetPadW() == 0 // -q pad_w fixed - && problem.GetPadH() == 0 // -p pad_h fixed - && problem.GetKernelStrideW() == 2 // -v inp_v fixed - && problem.GetKernelStrideH() == 2 // -u inp_u fixed - && problem.GetWeightsWidth() == 10 // -x wei_w fixed - && problem.GetWeightsHeight() == 5 // -y wei_h fixed + return problem.GetPadW() == 0 // -q pad_w fixed + && problem.GetPadH() == 0 // -p pad_h fixed + && problem.GetKernelStrideW() == 2 // -v inp_v fixed + && problem.GetKernelStrideH() == 2 // -u inp_u fixed + && problem.GetWeightsWidth_() == 10 // -x wei_w fixed + && problem.GetWeightsHeight_() == 5 // -y wei_h fixed && problem.GetDilationW() == 1 && problem.GetDilationH() == 1 - && problem.GetOutChannels() % 16 == 0 // -c wei_c no upper limit - && problem.GetInChannels() >= 16 // -k wei_k no upper limit - && problem.GetOutWidth() >= min_out_width // -W inp_w - && problem.GetOutWidth() <= max_out_width - && problem.GetOutHeight() >= min_out_height // -H inp_h - && problem.GetOutHeight() <= max_out_height + && problem.GetOutChannels_() % 16 == 0 // -c wei_c no upper limit + && problem.GetInChannels_() >= 16 // -k wei_k no upper limit + && problem.GetOutWidth_() >= min_out_width // -W inp_w + && problem.GetOutWidth_() <= max_out_width + && problem.GetOutHeight_() >= min_out_height // -H inp_h + && problem.GetOutHeight_() <= max_out_height && problem.IsFp32() && problem.GetGroupCount() == 1 && problem.GetOutLayout() == "NCHW"; // hardcoded @@ -109,10 +109,10 @@ ConvSolution ConvAsm5x10u2v2b1::GetSolution(const ExecutionContext& ctx, { ConvSolution result; std::ostringstream options; - GenerateClangDefsym(options, "inp_h", problem.GetOutHeight()); - GenerateClangDefsym(options, "inp_w", problem.GetOutWidth()); - GenerateClangDefsym(options, "wei_c", problem.GetOutChannels()); - GenerateClangDefsym(options, "wei_k", problem.GetInChannels()); + GenerateClangDefsym(options, "inp_h", problem.GetOutHeight_()); + GenerateClangDefsym(options, "inp_w", problem.GetOutWidth_()); + GenerateClangDefsym(options, "wei_c", problem.GetOutChannels_()); + GenerateClangDefsym(options, "wei_k", problem.GetInChannels_()); GenerateClangDefsym(options, "ROCM_METADATA_VERSION", ctx.rmv.UseV3() ? 5 : 4); KernelInfo constr_params; @@ -123,10 +123,10 @@ ConvSolution ConvAsm5x10u2v2b1::GetSolution(const ExecutionContext& ctx, constr_params.l_wk.push_back(1); // global-work = [align(out_w,64), (align(out_h,4)/4)*align(wei_c/2,8), batch_n] - constr_params.g_wk.push_back(AlignUp(problem.GetInWidth(), 64)); - constr_params.g_wk.push_back(static_cast(AlignUp(problem.GetInHeight(), 4) / 4 * - AlignUp(problem.GetOutChannels() / 2, 8))); - constr_params.g_wk.push_back(problem.GetBatchSize()); + constr_params.g_wk.push_back(AlignUp(problem.GetInWidth_(), 64)); + constr_params.g_wk.push_back(static_cast(AlignUp(problem.GetInHeight_(), 4) / 4 * + AlignUp(problem.GetOutChannels_() / 2, 8))); + constr_params.g_wk.push_back(problem.GetBatchSize_()); constr_params.kernel_file = "conv5x10u2v2b1.s"; constr_params.kernel_name = "miopenConv5x10u2v2b1"; diff --git a/src/solver/conv_asm_5x10u2v2f1.cpp b/src/solver/conv_asm_5x10u2v2f1.cpp index 68cbc5bcd0..09e2d2abed 100644 --- a/src/solver/conv_asm_5x10u2v2f1.cpp +++ b/src/solver/conv_asm_5x10u2v2f1.cpp @@ -77,8 +77,8 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx, } // Min image + padding shall be not smaller than filter matrix. - const int min_in_width = problem.GetWeightsWidth() - problem.GetPadW() * 2; - const int min_in_height = problem.GetWeightsHeight() - problem.GetPadH() * 2; + const int min_in_width = static_cast(problem.GetWeightsWidth_()) - problem.GetPadW() * 2; + const int min_in_height = static_cast(problem.GetWeightsHeight_()) - problem.GetPadH() * 2; // These two found experimentally. const int max_in_width = 8192 - 1; const int max_in_height = 131077 - 1; @@ -86,19 +86,19 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx, // clang-format off return 0 <= problem.GetPadW() && problem.GetPadW() <= 5 // -q pad_w // [0..5] for now FIXME && 0 <= problem.GetPadH() && problem.GetPadH() <= 5 // -p pad_h // [0..5] for now FIXME - && problem.GetKernelStrideW() == 2 // -v inp_v fixed - && problem.GetKernelStrideH() == 2 // -u inp_u fixed - && problem.GetWeightsWidth() == 10 // -x wei_w fixed - && problem.GetWeightsHeight() == 5 // -y wei_h fixed + && problem.GetKernelStrideW() == 2 // -v inp_v fixed + && problem.GetKernelStrideH() == 2 // -u inp_u fixed + && problem.GetWeightsWidth_() == 10 // -x wei_w fixed + && problem.GetWeightsHeight_() == 5 // -y wei_h fixed && problem.GetDilationW() == 1 && problem.GetDilationH() == 1 - && problem.GetInChannels() >= 1 // -c wei_c no upper limit - && problem.GetOutChannels() % 16 == 0 // -k wei_k no upper limit - && problem.GetOutChannels() >= 1 - && problem.GetInWidth() >= min_in_width // -W inp_w - && problem.GetInWidth() <= max_in_width - && problem.GetInHeight() >= min_in_height // -H inp_h - && problem.GetInHeight() <= max_in_height + && problem.GetInChannels_() >= 1 // -c wei_c no upper limit + && problem.GetOutChannels_() % 16 == 0 // -k wei_k no upper limit + && problem.GetOutChannels_() >= 1 + && static_cast(problem.GetInWidth_()) >= min_in_width // -W inp_w + && problem.GetInWidth_() <= max_in_width + && static_cast(problem.GetInHeight_()) >= min_in_height // -H inp_h + && problem.GetInHeight_() <= max_in_height && problem.IsFp32() && problem.GetGroupCount() == 1 && problem.GetInLayout() == "NCHW"; // hardcoded @@ -115,18 +115,18 @@ ConvSolution ConvAsm5x10u2v2f1::GetSolution(const ExecutionContext& ctx, const ProblemDescription& problem) const { ConvSolution result; - const int out_w = (problem.GetInWidth() + problem.GetPadW() * 2 + problem.GetKernelStrideW() - - problem.GetWeightsWidth()) / + const int out_w = (static_cast(problem.GetInWidth_()) + problem.GetPadW() * 2 + + problem.GetKernelStrideW() - static_cast(problem.GetWeightsWidth_())) / problem.GetKernelStrideW(); // (inp_w + 2*pad_w + inp_v - wei_w) / inp_v - const int out_h = (problem.GetInHeight() + problem.GetPadH() * 2 + problem.GetKernelStrideH() - - problem.GetWeightsHeight()) / + const int out_h = (static_cast(problem.GetInHeight_()) + problem.GetPadH() * 2 + + problem.GetKernelStrideH() - static_cast(problem.GetWeightsHeight_())) / problem.GetKernelStrideH(); // (inp_h + 2*pad_h + inp_u - wei_h) / inp_u std::ostringstream options; - GenerateClangDefsym(options, "inp_h", problem.GetInHeight()); - GenerateClangDefsym(options, "inp_w", problem.GetInWidth()); - GenerateClangDefsym(options, "wei_c", problem.GetInChannels()); - GenerateClangDefsym(options, "wei_k", problem.GetOutChannels()); + GenerateClangDefsym(options, "inp_h", problem.GetInHeight_()); + GenerateClangDefsym(options, "inp_w", problem.GetInWidth_()); + GenerateClangDefsym(options, "wei_c", problem.GetInChannels_()); + GenerateClangDefsym(options, "wei_k", problem.GetOutChannels_()); GenerateClangDefsym(options, "wei_layout", 0); // 0: KCHW, 1: CKHW GenerateClangDefsym(options, "pad_w", problem.GetPadW()); GenerateClangDefsym(options, "pad_h", problem.GetPadH()); @@ -142,8 +142,8 @@ ConvSolution ConvAsm5x10u2v2f1::GetSolution(const ExecutionContext& ctx, // global-work = [align(out_w,64), (align(out_h,4)/4)*align(wei_k/2,8), batch_n] construction_params.g_wk.push_back(AlignUp(out_w, 64)); construction_params.g_wk.push_back( - static_cast(AlignUp(out_h, 4) / 4 * AlignUp(problem.GetOutChannels() / 2, 8))); - construction_params.g_wk.push_back(problem.GetBatchSize()); + static_cast(AlignUp(out_h, 4) / 4 * AlignUp(problem.GetOutChannels_() / 2, 8))); + construction_params.g_wk.push_back(problem.GetBatchSize_()); construction_params.kernel_file = "conv5x10u2v2f1.s"; construction_params.kernel_name = "miopenConv5x10u2v2f1"; diff --git a/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp b/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp index f52e965601..4310a87fd6 100644 --- a/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp +++ b/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp @@ -75,18 +75,18 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx } // clang-format off - return problem.GetPadW() == 3 // -q - && problem.GetPadH() == 3 // -p - && problem.GetKernelStrideW() == 2 // -v - && problem.GetKernelStrideH() == 2 // -u - && problem.GetWeightsWidth() == 7 // -x - && problem.GetWeightsHeight() == 7 // -y + return problem.GetPadW() == 3 // -q + && problem.GetPadH() == 3 // -p + && problem.GetKernelStrideW() == 2 // -v + && problem.GetKernelStrideH() == 2 // -u + && problem.GetWeightsWidth_() == 7 // -x + && problem.GetWeightsHeight_() == 7 // -y && problem.GetDilationW() == 1 && problem.GetDilationH() == 1 - && problem.GetInChannels() == 3 // -c - && problem.GetOutChannels() == 64 // -k - && problem.GetInWidth() == 224 // -W - && problem.GetInHeight() == 224 // -H + && problem.GetInChannels_() == 3 // -c + && problem.GetOutChannels_() == 64 // -k + && problem.GetInWidth_() == 224 // -W + && problem.GetInHeight_() == 224 // -H && problem.IsFp32() && problem.GetGroupCount() == 1 && problem.GetInLayout() == "NCHW"; @@ -98,11 +98,11 @@ ConvSolution ConvAsm7x7c3h224w224k64u2v2p3q3f1::GetSolution(const ExecutionConte const ProblemDescription& problem) const { ConvSolution result; - const int out_w = (problem.GetInWidth() + problem.GetPadW() * 2 + problem.GetKernelStrideW() - - problem.GetWeightsWidth()) / + const int out_w = (static_cast(problem.GetInWidth_()) + problem.GetPadW() * 2 + + problem.GetKernelStrideW() - static_cast(problem.GetWeightsWidth_())) / problem.GetKernelStrideW(); // (inp_w + 2*pad_w + inp_v - wei_w) / inp_v - const int out_h = (problem.GetInHeight() + problem.GetPadH() * 2 + problem.GetKernelStrideH() - - problem.GetWeightsHeight()) / + const int out_h = (static_cast(problem.GetInHeight_()) + problem.GetPadH() * 2 + + problem.GetKernelStrideH() - static_cast(problem.GetWeightsHeight_())) / problem.GetKernelStrideH(); // (inp_h + 2*pad_h + inp_u - wei_h) / inp_u std::ostringstream options; @@ -117,8 +117,8 @@ ConvSolution ConvAsm7x7c3h224w224k64u2v2p3q3f1::GetSolution(const ExecutionConte // global-work = [align(out_w,64), (align(out_h,4)/4)*align(wei_k/2,8), batch_n] constr_params.g_wk.push_back(AlignUp(out_w, 64)); constr_params.g_wk.push_back( - static_cast(AlignUp(out_h, 4) / 4 * AlignUp(problem.GetOutChannels() / 2, 8))); - constr_params.g_wk.push_back(problem.GetBatchSize()); + static_cast(AlignUp(out_h, 4) / 4 * AlignUp(problem.GetOutChannels_() / 2, 8))); + constr_params.g_wk.push_back(problem.GetBatchSize_()); constr_params.kernel_file = "conv7x7c3h224w224k64u2v2p3q3f1.s"; constr_params.kernel_name = "miopenGcnAsmConv7x7c3h224w224k64u2v2p3q3f1"; diff --git a/src/solver/conv_asm_dir_BwdWrW1x1.cpp b/src/solver/conv_asm_dir_BwdWrW1x1.cpp index bdfd2876fe..4cd78b7357 100644 --- a/src/solver/conv_asm_dir_BwdWrW1x1.cpp +++ b/src/solver/conv_asm_dir_BwdWrW1x1.cpp @@ -55,12 +55,12 @@ static inline bool UseSubsample(const ProblemDescription& problem) /// out_height/out_width and vice versa. static inline int AsmImgHeight(const ProblemDescription& problem) { - return UseSubsample(problem) ? problem.GetInHeight() : problem.GetOutHeight(); + return UseSubsample(problem) ? problem.GetInHeight_() : problem.GetOutHeight_(); } static inline int AsmImgWidth(const ProblemDescription& problem) { - return UseSubsample(problem) ? problem.GetInWidth() : problem.GetOutWidth(); + return UseSubsample(problem) ? problem.GetInWidth_() : problem.GetOutWidth_(); } inline static bool Inc_1_2_4_8_16(int& v) @@ -327,7 +327,7 @@ bool PerformanceConfigConvAsmBwdWrW1x1::IsValid(const ConvolutionContext& ctx, { const int sequential_channels = 2; if((c_mult % sequential_channels) != 0 || - (problem.GetOutChannels() % sequential_channels) != 0) + (problem.GetOutChannels_() % sequential_channels) != 0) return false; } } @@ -370,10 +370,10 @@ void PerformanceConfigConvAsmBwdWrW1x1::HeuristicInit(const ConvolutionContext& : 0; read_size = 4; n_per_gpr = - (problem.GetBatchSize() >= 4 && (AsmImgHeight(problem) * AsmImgWidth(problem)) <= 128) ? 4 - : 1; + (problem.GetBatchSize_() >= 4 && (AsmImgHeight(problem) * AsmImgWidth(problem)) <= 128) ? 4 + : 1; data_prefetch = 1; - const auto c_k_256 = problem.GetOutChannels() * problem.GetInChannels() / 256; // C*K/256 + const auto c_k_256 = problem.GetOutChannels_() * problem.GetInChannels_() / 256; // C*K/256 if(c_k_256 < 2) { c_per_gpr = 1; @@ -498,7 +498,7 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ConvolutionContext& ctx, return false; } - if(name == "gfx90a" && problem.conv_problem.IsGfx90aFp16altRequired()) + if(name == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; // clang-format off @@ -507,8 +507,8 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ConvolutionContext& ctx, && problem.GetKernelStrideW() <= 2 // -v stride_w && problem.GetKernelStrideH() <= 2 // -u stride_h && problem.GetKernelStrideW() == problem.GetKernelStrideH() - && problem.GetWeightsWidth() == 1 // -x S wei_w - && problem.GetWeightsHeight() == 1 // -y R wei_h + && problem.GetWeightsWidth_() == 1 // -x S wei_w + && problem.GetWeightsHeight_() == 1 // -y R wei_h && problem.GetDilationW() == 1 && problem.GetDilationH() == 1 && problem.GetBias() == 0 @@ -521,15 +521,15 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ConvolutionContext& ctx, } // Check limits: const auto h_w = static_cast(AsmImgHeight(problem)) * AsmImgWidth(problem); - const auto r_s = static_cast(problem.GetWeightsHeight()) * problem.GetWeightsWidth(); - const auto c_h_w = static_cast(problem.GetOutChannels()) * h_w; // C*H*W - const auto k_h_w = static_cast(problem.GetInChannels()) * h_w; // K*H*W - const auto n_c_h_w = static_cast(problem.GetBatchSize()) * c_h_w; // N*C*H*W - const auto n_k_h_w = static_cast(problem.GetBatchSize()) * k_h_w; // N*K*H*W - const auto c_k_r_s = static_cast(problem.GetOutChannels()) * problem.GetInChannels() * r_s; // C*K*R*S - ok = problem.GetBatchSize() < std::pow(2, 16) // -n N batch_size - && problem.GetOutChannels() < std::pow(2, 16) // -c C input_channels - && problem.GetInChannels() < std::pow(2, 16) // -k K output_channels + const auto r_s = static_cast(problem.GetWeightsHeight_()) * problem.GetWeightsWidth_(); + const auto c_h_w = static_cast(problem.GetOutChannels_()) * h_w; // C*H*W + const auto k_h_w = static_cast(problem.GetInChannels_()) * h_w; // K*H*W + const auto n_c_h_w = static_cast(problem.GetBatchSize_()) * c_h_w; // N*C*H*W + const auto n_k_h_w = static_cast(problem.GetBatchSize_()) * k_h_w; // N*K*H*W + const auto c_k_r_s = static_cast(problem.GetOutChannels_()) * problem.GetInChannels_() * r_s; // C*K*R*S + ok = problem.GetBatchSize_() < std::pow(2, 16) // -n N batch_size + && problem.GetOutChannels_() < std::pow(2, 16) // -c C input_channels + && problem.GetInChannels_() < std::pow(2, 16) // -k K output_channels && c_h_w < std::pow(2, 24) && k_h_w < std::pow(2, 24) && n_c_h_w < std::pow(2, 29) @@ -553,8 +553,8 @@ size_t ConvAsmBwdWrW1x1::GetWorkspaceSize(const ConvolutionContext&, { int data_len = GetTypeSize(problem.GetOutDataType()); int in_batch_stride = - problem.GetInStride() * problem.GetInHeight() * problem.GetOutChannels(); - return static_cast(in_batch_stride) * problem.GetBatchSize() * data_len; + problem.GetInStrideH_() * problem.GetInHeight_() * problem.GetOutChannels_(); + return static_cast(in_batch_stride) * problem.GetBatchSize_() * data_len; } else return 0; @@ -574,11 +574,11 @@ ConvSolution ConvAsmBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, { // subsampled input, in_height equals to image size after downsampling int in_batch_stride = - problem.GetInStride() * problem.GetInHeight() * problem.GetOutChannels(); - int write_unit = (problem.GetInWidth() % 4 == 0) ? 4 - : (problem.GetInWidth() % 3 == 0) ? 3 - : (problem.GetInWidth() % 2 == 0) ? 2 - : 1; + problem.GetInStrideH_() * problem.GetInHeight_() * problem.GetOutChannels_(); + int write_unit = (problem.GetInWidth_() % 4 == 0) ? 4 + : (problem.GetInWidth_() % 3 == 0) ? 3 + : (problem.GetInWidth_() % 2 == 0) ? 2 + : 1; int n_grp0_size0 = 256; // clang-format off @@ -588,12 +588,12 @@ ConvSolution ConvAsmBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, std::string(" -DMLO_FILTER0_STRIDE0=") + std::to_string(problem.GetKernelStrideW()) + std::string(" -DMLO_FILTER0_STRIDE1=") + std::to_string(problem.GetKernelStrideH()) + std::string(" -DMLO_WRITE_UNIT=") + std::to_string(write_unit) + - std::string(" -DMLO_OUT_CHANNEL_STRIDE=") + std::to_string(problem.GetInChannelStride()) + - std::string(" -DMLO_OUT_STRIDE=") + std::to_string(problem.GetInStride()) + + std::string(" -DMLO_OUT_CHANNEL_STRIDE=") + std::to_string(problem.GetInChannelStride_()) + + std::string(" -DMLO_OUT_STRIDE=") + std::to_string(problem.GetInStrideH_()) + std::string(" -DMLO_IN_BATCH_STRIDE=") + std::to_string(in_batch_stride) + - std::string(" -DMLO_IN0_BATCH_STRIDE=") + std::to_string(problem.GetOutBatchStride()) + - std::string(" -DMLO_IN0_CHANNEL_STRIDE=") + std::to_string(problem.GetOutChannelStride()) + - std::string(" -DMLO_IN0_STRIDE=") + std::to_string(problem.GetOutStride()) + + std::string(" -DMLO_IN0_BATCH_STRIDE=") + std::to_string(problem.GetOutBatchStride_()) + + std::string(" -DMLO_IN0_CHANNEL_STRIDE=") + std::to_string(problem.GetOutChannelStride_()) + + std::string(" -DMLO_IN0_STRIDE=") + std::to_string(problem.GetOutStrideH_()) + ctx.general_compile_options; // clang-format on @@ -604,7 +604,7 @@ ConvSolution ConvAsmBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, kernel.l_wk.push_back(1); // output is number of subsampled input maps size_t gbl_wk0 = (in_batch_stride / write_unit); - size_t gbl_wk1 = problem.GetBatchSize(); + size_t gbl_wk1 = problem.GetBatchSize_(); size_t gbl_wk2 = 1; kernel.g_wk.push_back(gbl_wk0); @@ -627,12 +627,12 @@ ConvSolution ConvAsmBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, GenerateClangDefsym(options, "out_h", AsmImgHeight(problem)); // output H GenerateClangDefsym(options, "out_w", AsmImgWidth(problem)); // output W - GenerateClangDefsym(options, "batch_size", problem.GetBatchSize()); // N + GenerateClangDefsym(options, "batch_size", problem.GetBatchSize_()); // N // Note that problem.n_outputs and problem.n_inputs are swapped for backward convolutions. - GenerateClangDefsym(options, "input_channels", problem.GetOutChannels()); // C - GenerateClangDefsym(options, "output_channels", problem.GetInChannels()); // K - GenerateClangDefsym(options, "wei_h", problem.GetWeightsHeight()); // R - GenerateClangDefsym(options, "wei_w", problem.GetWeightsWidth()); // S + GenerateClangDefsym(options, "input_channels", problem.GetOutChannels_()); // C + GenerateClangDefsym(options, "output_channels", problem.GetInChannels_()); // K + GenerateClangDefsym(options, "wei_h", problem.GetWeightsHeight_()); // R + GenerateClangDefsym(options, "wei_w", problem.GetWeightsWidth_()); // S GenerateClangDefsym(options, "pad_h", problem.GetPadH()); GenerateClangDefsym(options, "pad_w", problem.GetPadW()); GenerateClangDefsym(options, "weights_layout", 0); @@ -699,23 +699,23 @@ ConvSolution ConvAsmBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, // cppcheck-suppress unreadVariable buff_info ibuf(MemLayout::NCHW, - problem.GetBatchSize(), - problem.GetOutChannels(), + problem.GetBatchSize_(), + problem.GetOutChannels_(), AsmImgHeight(problem), AsmImgWidth(problem), 1, data_len); // cppcheck-suppress unreadVariable buff_info obuf(MemLayout::NCHW, - problem.GetBatchSize(), - problem.GetInChannels(), + problem.GetBatchSize_(), + problem.GetInChannels_(), AsmImgHeight(problem), AsmImgWidth(problem), 1, data_len); // cppcheck-suppress unreadVariable buff_info fbuf( - MemLayout::NCHW, problem.GetInChannels(), problem.GetOutChannels(), 1, 1, 1, data_len); + MemLayout::NCHW, problem.GetInChannels_(), problem.GetOutChannels_(), 1, 1, 1, data_len); GenerateClangDefsym(options, "input_n_stride", ibuf.byte_stride.nk); GenerateClangDefsym(options, "input_c_stride", ibuf.byte_stride.c); GenerateClangDefsym(options, "input_h_stride", ibuf.byte_stride.h); @@ -784,9 +784,9 @@ ConvSolution ConvAsmBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, kernel.g_wk.clear(); // gridsize kernel.g_wk.push_back(static_cast(solver::wave_size) * pcfg->GetNPartCnt()); kernel.g_wk.push_back( - divide_round_plus_inf(problem.GetOutChannels(), pcfg->GetCPerGpr() * pcfg->GetCMult())); + divide_round_plus_inf(problem.GetOutChannels_(), pcfg->GetCPerGpr() * pcfg->GetCMult())); kernel.g_wk.push_back( - divide_round_plus_inf(problem.GetInChannels(), pcfg->GetKPerGpr() * pcfg->GetKMult())); + divide_round_plus_inf(problem.GetInChannels_(), pcfg->GetKPerGpr() * pcfg->GetKMult())); kernel.kernel_file = "conv1x1wrw.s"; kernel.kernel_name = "miopenGcnAsmConv1x1WrW"; diff --git a/src/solver/conv_asm_dir_BwdWrW3x3.cpp b/src/solver/conv_asm_dir_BwdWrW3x3.cpp index 7b04c5f140..2781d25d07 100644 --- a/src/solver/conv_asm_dir_BwdWrW3x3.cpp +++ b/src/solver/conv_asm_dir_BwdWrW3x3.cpp @@ -151,39 +151,41 @@ bool PerformanceConfigAsmDirect3x3WrW::IsValid(const ConvolutionContext& ctx, assert(chunk_size != 0); if(reverse_inout == 0) { - if((problem.GetOutChannels() % (GetCPerWave() * problem.GetGroupCount()) != 0) || - (problem.GetInChannels() % (GetKPerWave() * problem.GetGroupCount()) != 0)) + if((problem.GetOutChannels_() % (GetCPerWave() * problem.GetGroupCount()) != 0) || + (problem.GetInChannels_() % (GetKPerWave() * problem.GetGroupCount()) != 0)) return false; } else { - if((problem.GetOutChannels() % (GetKPerWave() * problem.GetGroupCount()) != 0) || - (problem.GetInChannels() % (GetCPerWave() * problem.GetGroupCount()) != 0)) + if((problem.GetOutChannels_() % (GetKPerWave() * problem.GetGroupCount()) != 0) || + (problem.GetInChannels_() % (GetCPerWave() * problem.GetGroupCount()) != 0)) return false; } - if((problem.GetOutChannels() % (64 / chunk_size) != 0) && - (problem.GetInChannels() % (64 / chunk_size) != 0)) + if((problem.GetOutChannels_() % (64 / chunk_size) != 0) && + (problem.GetInChannels_() % (64 / chunk_size) != 0)) return false; - if((reverse_inout != 0 ? problem.GetInChannels() : problem.GetOutChannels()) % GetCPerWave() != + if((reverse_inout != 0 ? problem.GetInChannels_() : problem.GetOutChannels_()) % + GetCPerWave() != 0) return false; if(!(chunk_size * k_per_wave <= 64)) return false; - if((reverse_inout != 0 ? problem.GetOutChannels() : problem.GetInChannels()) % k_per_wave != 0) + if((reverse_inout != 0 ? problem.GetOutChannels_() : problem.GetInChannels_()) % k_per_wave != + 0) return false; - if(!(n_per_group <= problem.GetBatchSize())) + if(!(n_per_group <= problem.GetBatchSize_())) return false; - if(!(1 <= pipe_lines_depth && pipe_lines_depth <= std::min(problem.GetOutHeight(), 16))) + if(!(1 <= pipe_lines_depth && pipe_lines_depth <= std::min(problem.GetOutHeight_(), 16U))) return false; if((reverse_inout != 0) && !IsReverseInOutAllowed(problem)) return false; { - const int accums_cnt = (problem.GetWeightsWidth() * problem.GetWeightsHeight() * + const int accums_cnt = (problem.GetWeightsWidth_() * problem.GetWeightsHeight_() * GetCPerWave() * k_per_wave * chunk_size) / 64; assert(chunk_size); const int out_w_vec = - (problem.GetOutWidth() + elements_in_dword(problem) - 1) / elements_in_dword(problem); + (problem.GetOutWidth_() + elements_in_dword(problem) - 1) / elements_in_dword(problem); int gprs_per_line_in = (out_w_vec + chunk_size - 1) / chunk_size; if(chunk_size != 16) { @@ -196,7 +198,7 @@ bool PerformanceConfigAsmDirect3x3WrW::IsValid(const ConvolutionContext& ctx, const int gprs_per_line_out = (gprs_per_line_in > 1) ? gprs_per_line_in / problem.GetKernelStrideW() : 1; - const int lines_in = pipe_lines_depth + problem.GetWeightsHeight() - 1; + const int lines_in = pipe_lines_depth + problem.GetWeightsHeight_() - 1; const int vgprs_for_lines_in = lines_in * elements_in_dword(problem) * gprs_per_line_in; assert(problem.GetKernelStrideH()); const int lines_out = @@ -205,7 +207,7 @@ bool PerformanceConfigAsmDirect3x3WrW::IsValid(const ConvolutionContext& ctx, const int vgprs_for_division = (vgprs_for_lines_in >= 4 ? 0 : 4) + (vgprs_for_lines_out >= 3 ? 0 : 3); - const int k_group_size = problem.GetInChannels() / + const int k_group_size = problem.GetInChannels_() / (reverse_inout != 0 ? GetCPerWave() : GetKPerWave()) / problem.GetGroupCount(); const bool k_group_size_is_power_of_two = ((k_group_size & (k_group_size - 1)) == 0); @@ -225,7 +227,8 @@ bool PerformanceConfigAsmDirect3x3WrW::IsValid(const ConvolutionContext& ctx, return false; const int unroll_factor = pipe_lines_depth * (pipe_lines_depth + 2); - const int steps = std::max(0, problem.GetOutHeight() - 1 - pipe_lines_depth); + const int steps = + std::max(0, static_cast(problem.GetOutHeight_()) - 1 - pipe_lines_depth); assert(unroll_factor); const int loops = pipe_lines_depth + unroll_factor + steps % unroll_factor + 1; const int m_instr = 3 + (gprs_per_line_in + 3) / 4; @@ -234,9 +237,10 @@ bool PerformanceConfigAsmDirect3x3WrW::IsValid(const ConvolutionContext& ctx, /// information here and in all similar places across other Solvers. const bool dot2_inst_avail = (name == "gfx906" || name == "gfx908"); const bool dot2_emulate = (!dot2_inst_avail) && (elements_in_dword(problem) == 2); - const int v_instr = (k_per_wave * problem.GetWeightsHeight() * gprs_per_line_out * - problem.GetWeightsWidth() * 4 * (dot2_emulate ? 2 : 1)) / - 3 * elements_in_dword(problem); + const int v_instr = + (k_per_wave * static_cast(problem.GetWeightsHeight_()) * gprs_per_line_out * + static_cast(problem.GetWeightsWidth_()) * 4 * (dot2_emulate ? 2 : 1)) / + 3 * elements_in_dword(problem); const int exch_instr = elements_in_dword(problem) == 2 ? 3 * m_instr : 0; const int total = loops * (m_instr + v_instr + exch_instr) * elements_in_dword(problem); // instructions @@ -251,25 +255,26 @@ void PerformanceConfigAsmDirect3x3WrW::HeuristicInit(const ConvolutionContext& c { limit_wave_cnt = 0; - chunk_size = (problem.GetOutWidth() < 48) ? 8 : 16; - if((problem.GetOutChannels() % (64 / chunk_size) != 0) && - (problem.GetInChannels() % (64 / chunk_size) != 0)) + chunk_size = (problem.GetOutWidth_() < 48) ? 8 : 16; + if((problem.GetOutChannels_() % (64 / chunk_size) != 0) && + (problem.GetInChannels_() % (64 / chunk_size) != 0)) chunk_size = 16; // Fixup for correctness reverse_inout = 0; if(IsReverseInOutAllowed(problem) && - ((problem.GetOutChannels() % 4 != 0) || (problem.GetOutWidth() < 8))) + ((problem.GetOutChannels_() % 4 != 0) || (problem.GetOutWidth_() < 8))) reverse_inout = 1; const auto c_k = - problem.GetOutChannels() * problem.GetInChannels() / problem.GetGroupCount(); // C*K + problem.GetOutChannels_() * problem.GetInChannels_() / problem.GetGroupCount(); // C*K if(c_k < 256) k_per_wave = 1; else if(c_k < 16384) k_per_wave = 2; else // C*K >= 16k k_per_wave = ((chunk_size == 8) ? 2 : 4); - while((reverse_inout != 0 ? problem.GetOutChannels() : problem.GetInChannels()) % k_per_wave != + while((reverse_inout != 0 ? problem.GetOutChannels_() : problem.GetInChannels_()) % + k_per_wave != 0) k_per_wave /= 2; // Fixup for correctness @@ -281,16 +286,16 @@ void PerformanceConfigAsmDirect3x3WrW::HeuristicInit(const ConvolutionContext& c n_per_group = 2; else n_per_group = 1; - if(n_per_group > problem.GetBatchSize()) - n_per_group = problem.GetBatchSize(); // n_per_group should never be > batch size. - if(problem.GetOutWidth() >= 256 && + if(n_per_group > problem.GetBatchSize_()) + n_per_group = problem.GetBatchSize_(); // n_per_group should never be > batch size. + if(problem.GetOutWidth_() >= 256 && n_per_group > 4) // when width >= 256, n_per_group should not be > 4. n_per_group = 4; - pipe_lines_depth = (problem.GetOutHeight() <= 1) ? 1 : 2; - if((problem.GetOutHeight() < 8) && (problem.GetOutWidth() < 64)) + pipe_lines_depth = (problem.GetOutHeight_() <= 1) ? 1 : 2; + if((problem.GetOutHeight_() < 8) && (problem.GetOutWidth_() < 64)) { - pipe_lines_depth = problem.GetOutHeight(); // Special case. + pipe_lines_depth = problem.GetOutHeight_(); // Special case. } if(!IsValid(ctx, problem)) @@ -302,7 +307,7 @@ void PerformanceConfigAsmDirect3x3WrW::HeuristicInit(const ConvolutionContext& c k_per_wave = 1; pipe_lines_depth = 2; n_per_group = 1; - if(problem.GetOutChannels() % (4 * problem.GetGroupCount()) != 0) + if(problem.GetOutChannels_() % (4 * problem.GetGroupCount()) != 0) { /// (1) If reverse is Off, then both (C % c_per_wave) and (K % k_per_wave) must be 0. /// Toggling reverse swaps C and K in the condition above. @@ -379,7 +384,7 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ConvolutionContext& ctx, return false; #endif - if(name == "gfx90a" && problem.conv_problem.IsGfx90aFp16altRequired()) + if(name == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; #if WORKAROUND_SWDEV_330460 @@ -393,8 +398,8 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ConvolutionContext& ctx, && problem.GetPadH() == 1 // -p pad_h && problem.GetKernelStrideW() <= 2 // -v stride_w && problem.GetKernelStrideH() <= 2 // -u stride_h - && problem.GetWeightsWidth() == 3 // -x S wei_w - && problem.GetWeightsHeight() == 3 // -y R wei_h + && problem.GetWeightsWidth_() == 3 // -x S wei_w + && problem.GetWeightsHeight_() == 3 // -y R wei_h && problem.GetDilationW() == 1 && problem.GetDilationH() == 1 && problem.GetBias() == 0 @@ -405,28 +410,28 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ConvolutionContext& ctx, if(problem.IsFp16() && (StartsWith(name, "gfx8") // Not supported. - || problem.GetBatchSize() % 2 != 0)) /// \todo Initial version. + || problem.GetBatchSize_() % 2 != 0)) /// \todo Initial version. return false; // Check limits: - const auto h_w = static_cast(problem.GetOutHeight()) * problem.GetOutWidth(); - const auto r_s = static_cast(problem.GetWeightsHeight()) * problem.GetWeightsWidth(); - const auto c_h_w = static_cast(problem.GetOutChannels()) * h_w; // C*H*W - const auto k_h_w = static_cast(problem.GetInChannels()) * h_w; // K*H*W - const auto c_r_s = static_cast(problem.GetOutChannels()) * r_s; // C*R*S - const auto k_r_s = static_cast(problem.GetInChannels()) * r_s; // K*R*S - const auto n_c_h_w = static_cast(problem.GetBatchSize()) * c_h_w; // N*C*H*W - const auto n_k_h_w = static_cast(problem.GetBatchSize()) * k_h_w; // N*K*H*W - const auto c_k_r_s = static_cast(problem.GetOutChannels()) * k_r_s; // C*K*R*S - ok = problem.GetOutWidth() > 0 - && problem.GetOutWidth() <= 512 + const auto h_w = static_cast(problem.GetOutHeight_()) * problem.GetOutWidth_(); + const auto r_s = static_cast(problem.GetWeightsHeight_()) * problem.GetWeightsWidth_(); + const auto c_h_w = static_cast(problem.GetOutChannels_()) * h_w; // C*H*W + const auto k_h_w = static_cast(problem.GetInChannels_()) * h_w; // K*H*W + const auto c_r_s = static_cast(problem.GetOutChannels_()) * r_s; // C*R*S + const auto k_r_s = static_cast(problem.GetInChannels_()) * r_s; // K*R*S + const auto n_c_h_w = static_cast(problem.GetBatchSize_()) * c_h_w; // N*C*H*W + const auto n_k_h_w = static_cast(problem.GetBatchSize_()) * k_h_w; // N*K*H*W + const auto c_k_r_s = static_cast(problem.GetOutChannels_()) * k_r_s; // C*K*R*S + ok = problem.GetOutWidth_() > 0 + && problem.GetOutWidth_() <= 512 && (IsReverseInOutAllowed(problem) - ? ((problem.GetOutChannels() % (4 * problem.GetGroupCount()) == 0) || (problem.GetInChannels() % (4 * problem.GetGroupCount()) == 0)) - : (problem.GetOutChannels() % (4 * problem.GetGroupCount()) == 0)) - && problem.GetOutHeight() < std::pow(2, 16) // -H H img_h - && problem.GetBatchSize() < std::pow(2, 16) // -n N batch_size - && problem.GetOutChannels() < std::pow(2, 16) // -c C input_channels - && problem.GetInChannels() < std::pow(2, 16) // -k K output_channels + ? ((problem.GetOutChannels_() % (4 * problem.GetGroupCount()) == 0) || (problem.GetInChannels_() % (4 * problem.GetGroupCount()) == 0)) + : (problem.GetOutChannels_() % (4 * problem.GetGroupCount()) == 0)) + && problem.GetOutHeight_() < std::pow(2, 16) // -H H img_h + && problem.GetBatchSize_() < std::pow(2, 16) // -n N batch_size + && problem.GetOutChannels_() < std::pow(2, 16) // -c C input_channels + && problem.GetInChannels_() < std::pow(2, 16) // -k K output_channels && c_h_w < std::pow(2, 22) && k_h_w < std::pow(2, 22) && c_r_s < std::pow(2, 22) @@ -444,14 +449,14 @@ ConvSolution ConvAsmBwdWrW3x3::GetSolution(const ConvolutionContext& ctx, ConvSolution result; std::ostringstream options; GenerateClangDefsym(options, "elements_in_dword", (problem.IsFp16()) ? 2 : 1); - GenerateClangDefsym(options, "batch_size", problem.GetBatchSize()); // N - GenerateClangDefsym(options, "img_h", problem.GetOutHeight()); // H - GenerateClangDefsym(options, "img_w", problem.GetOutWidth()); // W + GenerateClangDefsym(options, "batch_size", problem.GetBatchSize_()); // N + GenerateClangDefsym(options, "img_h", problem.GetOutHeight_()); // H + GenerateClangDefsym(options, "img_w", problem.GetOutWidth_()); // W // Note that problem.n_outputs and problem.n_inputs are swapped for backward convolutions. - GenerateClangDefsym(options, "input_channels", problem.GetOutChannels()); // C - GenerateClangDefsym(options, "output_channels", problem.GetInChannels()); // K - GenerateClangDefsym(options, "wei_h", problem.GetWeightsHeight()); // R - GenerateClangDefsym(options, "wei_w", problem.GetWeightsWidth()); // S + GenerateClangDefsym(options, "input_channels", problem.GetOutChannels_()); // C + GenerateClangDefsym(options, "output_channels", problem.GetInChannels_()); // K + GenerateClangDefsym(options, "wei_h", problem.GetWeightsHeight_()); // R + GenerateClangDefsym(options, "wei_w", problem.GetWeightsWidth_()); // S GenerateClangDefsym(options, "pad_h", problem.GetPadH()); GenerateClangDefsym(options, "pad_w", problem.GetPadW()); GenerateClangDefsym(options, "stride_h", problem.GetKernelStrideH()); @@ -498,7 +503,7 @@ ConvSolution ConvAsmBwdWrW3x3::GetSolution(const ConvolutionContext& ctx, GenerateClangDefsym(options, "group_counts", problem.GetGroupCount()); const int k_group_size = - problem.GetInChannels() / + problem.GetInChannels_() / (pcfg->reverse_inout != 0 ? pcfg->GetCPerWave() : pcfg->GetKPerWave()) / problem.GetGroupCount(); const bool k_group_size_is_power_of_two = ((k_group_size & (k_group_size - 1)) == 0); @@ -518,15 +523,15 @@ ConvSolution ConvAsmBwdWrW3x3::GetSolution(const ConvolutionContext& ctx, if(pcfg->GetReverseInout() == 0) { - kernel.g_wk.push_back(problem.GetOutChannels() / pcfg->GetCPerWave() / + kernel.g_wk.push_back(problem.GetOutChannels_() / pcfg->GetCPerWave() / problem.GetGroupCount()); - kernel.g_wk.push_back(problem.GetInChannels() / pcfg->GetKPerWave()); + kernel.g_wk.push_back(problem.GetInChannels_() / pcfg->GetKPerWave()); } else { - kernel.g_wk.push_back(problem.GetOutChannels() / pcfg->GetKPerWave() / + kernel.g_wk.push_back(problem.GetOutChannels_() / pcfg->GetKPerWave() / problem.GetGroupCount()); - kernel.g_wk.push_back(problem.GetInChannels() / pcfg->GetCPerWave()); + kernel.g_wk.push_back(problem.GetInChannels_() / pcfg->GetCPerWave()); } kernel.kernel_file = "conv3x3wrw.s"; diff --git a/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp index a0360b4aba..dbfdb0b69c 100644 --- a/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp @@ -44,21 +44,21 @@ static inline bool FindImplicitGemmDynamicKernelBwd(const ProblemDescription& pr // TODO: add more dynamic kernel to expand support range, and update this function // clang-format off // refer to ProblemInterpreter, in bwd most dimension is reversed - int hi = problem.GetOutHeight(); - int wi = problem.GetOutWidth(); - int n = problem.GetBatchSize(); - int k = problem.GetInChannels(); - int c = problem.GetOutChannels(); - int ho = problem.GetInHeight(); - int wo = problem.GetInWidth(); - int stride_h = problem.GetInHeight() > 1 ? problem.GetKernelStrideH() : 1; - int stride_w = problem.GetInWidth() > 1 ? problem.GetKernelStrideW() : 1; - int dilation_h = problem.GetWeightsHeight() > 1? problem.GetDilationH() : 1; - int dilation_w = problem.GetWeightsWidth() > 1? problem.GetDilationW() : 1; + int hi = problem.GetOutHeight_(); + int wi = problem.GetOutWidth_(); + int n = problem.GetBatchSize_(); + int k = problem.GetInChannels_(); + int c = problem.GetOutChannels_(); + int ho = problem.GetInHeight_(); + int wo = problem.GetInWidth_(); + int stride_h = problem.GetInHeight_() > 1 ? problem.GetKernelStrideH() : 1; + int stride_w = problem.GetInWidth_() > 1 ? problem.GetKernelStrideW() : 1; + int dilation_h = problem.GetWeightsHeight_() > 1? problem.GetDilationH() : 1; + int dilation_w = problem.GetWeightsWidth_() > 1? problem.GetDilationW() : 1; int pad_h = problem.GetPadH(); int pad_w = problem.GetPadW(); - int y = problem.GetWeightsHeight(); - int x = problem.GetWeightsWidth(); + int y = problem.GetWeightsHeight_(); + int x = problem.GetWeightsWidth_(); int gcd_stride_dilation_h = gcd(stride_h, dilation_h); int gcd_stride_dilation_w = gcd(stride_w, dilation_w); diff --git a/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp b/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp index d22d79f8c6..caa26d1f23 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp @@ -790,21 +790,21 @@ FindImplicitGemmGtcDynamicBwdKernel(const ProblemDescription& problem) // so far, "group" is only supported by bwd fp16 kernels const auto group = problem.IsFp16() ? problem.GetGroupCount() : 1; - const auto hi = problem.GetOutHeight(); - const auto wi = problem.GetOutWidth(); - const auto n = problem.GetBatchSize(); - const auto k = problem.GetInChannels() / group; - const auto c = problem.GetOutChannels() / group; - const auto ho = problem.GetInHeight(); - const auto wo = problem.GetInWidth(); - const auto stride_h = problem.GetOutHeight() > 1 ? problem.GetKernelStrideH() : 1; - const auto stride_w = problem.GetOutWidth() > 1 ? problem.GetKernelStrideW() : 1; - const auto dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; - const auto dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; + const int hi = problem.GetOutHeight_(); + const int wi = problem.GetOutWidth_(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetInChannels_() / group; + const int c = problem.GetOutChannels_() / group; + const int ho = problem.GetInHeight_(); + const int wo = problem.GetInWidth_(); + const auto stride_h = problem.GetOutHeight_() > 1 ? problem.GetKernelStrideH() : 1; + const auto stride_w = problem.GetOutWidth_() > 1 ? problem.GetKernelStrideW() : 1; + const auto dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + const auto dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; const auto pad_h = problem.GetPadH(); const auto pad_w = problem.GetPadW(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto gcd_stride_dilation_h = gcd(stride_h, dilation_h); const auto gcd_stride_dilation_w = gcd(stride_w, dilation_w); diff --git a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp index 1613caabab..41a2b018fa 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp @@ -408,21 +408,21 @@ GetImplicitGemmGtcDynamicBwdXdlopsNHWCKernel( const PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC& config) { const auto group = problem.GetGroupCount(); - const auto hi = problem.GetOutHeight(); - const auto wi = problem.GetOutWidth(); - const auto n = problem.GetBatchSize(); - const auto k = problem.GetInChannels(); - const auto c = problem.GetOutChannels(); - const auto ho = problem.GetInHeight(); - const auto wo = problem.GetInWidth(); - const auto stride_h = problem.GetOutHeight() > 1 ? problem.GetKernelStrideH() : 1; - const auto stride_w = problem.GetOutWidth() > 1 ? problem.GetKernelStrideW() : 1; - const auto dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; - const auto dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; + const int hi = problem.GetOutHeight_(); + const int wi = problem.GetOutWidth_(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetInChannels_(); + const int c = problem.GetOutChannels_(); + const int ho = problem.GetInHeight_(); + const int wo = problem.GetInWidth_(); + const auto stride_h = problem.GetOutHeight_() > 1 ? problem.GetKernelStrideH() : 1; + const auto stride_w = problem.GetOutWidth_() > 1 ? problem.GetKernelStrideW() : 1; + const auto dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + const auto dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; const auto pad_h = problem.GetPadH(); const auto pad_w = problem.GetPadW(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto gcd_stride_dilation_h = gcd(stride_h, dilation_h); const auto gcd_stride_dilation_w = gcd(stride_w, dilation_w); @@ -594,21 +594,21 @@ void PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC::HeuristicInit( #endif const auto group = problem.GetGroupCount(); - const auto hi = problem.GetOutHeight(); - const auto wi = problem.GetOutWidth(); - const auto n = problem.GetBatchSize(); - const auto k = problem.GetInChannels(); - const auto c = problem.GetOutChannels(); - const auto ho = problem.GetInHeight(); - const auto wo = problem.GetInWidth(); - const auto stride_h = problem.GetOutHeight() > 1 ? problem.GetKernelStrideH() : 1; - const auto stride_w = problem.GetOutWidth() > 1 ? problem.GetKernelStrideW() : 1; - const auto dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; - const auto dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; + const int hi = problem.GetOutHeight_(); + const int wi = problem.GetOutWidth_(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetInChannels_(); + const int c = problem.GetOutChannels_(); + const int ho = problem.GetInHeight_(); + const int wo = problem.GetInWidth_(); + const auto stride_h = problem.GetOutHeight_() > 1 ? problem.GetKernelStrideH() : 1; + const auto stride_w = problem.GetOutWidth_() > 1 ? problem.GetKernelStrideW() : 1; + const auto dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + const auto dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; const auto pad_h = problem.GetPadH(); const auto pad_w = problem.GetPadW(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto gcd_stride_dilation_h = gcd(stride_h, dilation_h); const auto gcd_stride_dilation_w = gcd(stride_w, dilation_w); @@ -802,22 +802,22 @@ bool PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC::IsValid( return false; const auto group = problem.GetGroupCount(); - const auto k = problem.GetInChannels(); - const auto c = problem.GetOutChannels(); - const auto stride_h = problem.GetOutHeight() > 1 ? problem.GetKernelStrideH() : 1; - const auto stride_w = problem.GetOutWidth() > 1 ? problem.GetKernelStrideW() : 1; - const auto dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; - const auto dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; + const int k = problem.GetInChannels_(); + const int c = problem.GetOutChannels_(); + const auto stride_h = problem.GetOutHeight_() > 1 ? problem.GetKernelStrideH() : 1; + const auto stride_w = problem.GetOutWidth_() > 1 ? problem.GetKernelStrideW() : 1; + const auto dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + const auto dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; const auto pad_h = problem.GetPadH(); const auto pad_w = problem.GetPadW(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); - const auto hi = problem.GetOutHeight(); - const auto wi = problem.GetOutWidth(); - const auto n = problem.GetBatchSize(); - const auto ho = problem.GetInHeight(); - const auto wo = problem.GetInWidth(); + const int hi = problem.GetOutHeight_(); + const int wi = problem.GetOutWidth_(); + const int n = problem.GetBatchSize_(); + const int ho = problem.GetInHeight_(); + const int wo = problem.GetInWidth_(); auto splits_4G = igemm_split_batch_size( hi, wi, ho, wo, n, k, c, miopen::GetTypeSize(problem.GetInDataType())); @@ -916,7 +916,7 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsApplicable( if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_GTC_XDLOPS_NHWC{})) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; #if WORKAROUND_ISSUE_1979 @@ -949,13 +949,13 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsApplicable( if(target.Xnack() && *target.Xnack()) return false; // NOLINT (readability-simplify-boolean-expr) - if(0 == igemm_split_batch_size(problem.GetOutHeight(), - problem.GetOutWidth(), - problem.GetInHeight(), - problem.GetInWidth(), - problem.GetBatchSize(), - problem.GetInChannels(), - problem.GetOutChannels(), + if(0 == igemm_split_batch_size(problem.GetOutHeight_(), + problem.GetOutWidth_(), + problem.GetInHeight_(), + problem.GetInWidth_(), + problem.GetBatchSize_(), + problem.GetInChannels_(), + problem.GetOutChannels_(), miopen::GetTypeSize(problem.GetInDataType()))) return false; { @@ -976,15 +976,15 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsApplicable( size_t ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::GetWorkspaceSize( const ConvolutionContext& ctx, const ProblemDescription& problem) const { - const auto hi = problem.GetOutHeight(); - const auto wi = problem.GetOutWidth(); - const auto n = problem.GetBatchSize(); - const auto k = problem.GetInChannels(); - const auto c = problem.GetOutChannels(); - const auto ho = problem.GetInHeight(); - const auto wo = problem.GetInWidth(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int hi = problem.GetOutHeight_(); + const int wi = problem.GetOutWidth_(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetInChannels_(); + const int c = problem.GetOutChannels_(); + const int ho = problem.GetInHeight_(); + const int wo = problem.GetInWidth_(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto group = problem.GetGroupCount(); const auto is_nchw = problem.IsLayoutDefault(); @@ -1060,7 +1060,7 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::GetSolution( kernel.l_wk.push_back(1); const auto isGfx90aFp16altSupport = - (ctx.GetStream().GetDeviceName() == "gfx90a") && problem.conv_problem.IsFp16(); + (ctx.GetStream().GetDeviceName() == "gfx90a") && problem.IsFp16(); const auto is_nchw = problem.IsLayoutDefault(); @@ -1102,20 +1102,20 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::GetSolution( GenerateClangDefsym(opts_1, "igemm_bwd_fp16_alt_impl", 1); result.construction_params[1].comp_options = opts_1.str(); if(miopen::IsLogging(LoggingLevel::Info2)) - msg << ", fp16_alt:" << problem.conv_problem.GetConv().attribute.gfx90aFp16alt.GetBwd(); + msg << ", fp16_alt:" << problem.GetConv().attribute.gfx90aFp16alt.GetBwd(); } if(is_nchw) { - const auto hi = problem.GetOutHeight(); - const auto wi = problem.GetOutWidth(); - const auto n = problem.GetBatchSize(); - const auto k = problem.GetInChannels(); - const auto c = problem.GetOutChannels(); - const auto ho = problem.GetInHeight(); - const auto wo = problem.GetInWidth(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int hi = problem.GetOutHeight_(); + const int wi = problem.GetOutWidth_(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetInChannels_(); + const int c = problem.GetOutChannels_(); + const int ho = problem.GetInHeight_(); + const int wo = problem.GetInWidth_(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto group = problem.GetGroupCount(); TransposeSolutionNhwc2Default trans_input(ctx, problem.GetOutDataType(), n, c, hi, wi); diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp index 4f4ebc87f8..081e12a532 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp @@ -1351,19 +1351,19 @@ static std::tuple 1 ? problem.GetKernelStrideH() : 1; - const auto stride_w = problem.GetInWidth() > 1 ? problem.GetKernelStrideW() : 1; - const auto dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; - const auto dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; + const int n = problem.GetBatchSize_(); + const int c = problem.GetInChannels_(); + const int k = problem.GetOutChannels_(); + const int ho = problem.GetOutHeight_(); + const int wo = problem.GetOutWidth_(); + const auto stride_h = problem.GetInHeight_() > 1 ? problem.GetKernelStrideH() : 1; + const auto stride_w = problem.GetInWidth_() > 1 ? problem.GetKernelStrideW() : 1; + const auto dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + const auto dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; const auto pad_h = problem.GetPadH(); const auto pad_w = problem.GetPadW(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto& gemm_m = k; const auto gemm_n = n * ho * wo; @@ -1533,8 +1533,8 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlops::IsApplicable(const ExecutionContext } #if WORKAROUND_SWDEV_306318 - if((problem.GetWeightsHeight() == 1) && (problem.GetWeightsWidth() == 1) && - (problem.GetInChannels() % 8 != 0)) + if((problem.GetWeightsHeight_() == 1) && (problem.GetWeightsWidth_() == 1) && + (problem.GetInChannels_() % 8 != 0)) if(!miopen::IsEnabled(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_GTC_XDLOPS{})) return false; #endif diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp index afed43ce31..c52372b6d2 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp @@ -261,15 +261,15 @@ GetImplicitGemmGtcDynamicFwdDlopsNCHWCKernel( const ProblemDescription& problem, const PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC& config) { - const auto n = problem.GetBatchSize(); - const auto k = problem.GetOutChannels() * config.vector_c; - const auto ho = problem.GetOutHeight(); - const auto wo = problem.GetOutWidth(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetOutChannels_() * config.vector_c; + const int ho = problem.GetOutHeight_(); + const int wo = problem.GetOutWidth_(); const auto group = problem.GetGroupCount(); - const auto hi = problem.GetInHeight(); - const auto wi = problem.GetInWidth(); - const auto c = problem.GetInChannels(); + const int hi = problem.GetInHeight_(); + const int wi = problem.GetInWidth_(); + const int c = problem.GetInChannels_(); auto splits_4G = igemm_split_batch_size( hi, wi, ho, wo, n, k, c, miopen::GetTypeSize(problem.GetInDataType())); @@ -370,13 +370,13 @@ void PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC::HeuristicInit( } #endif - const auto n = problem.GetBatchSize(); - const auto c = problem.GetInChannels(); - const auto k = problem.GetOutChannels(); - const auto ho = problem.GetOutHeight(); - const auto wo = problem.GetOutWidth(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int n = problem.GetBatchSize_(); + const int c = problem.GetInChannels_(); + const int k = problem.GetOutChannels_(); + const int ho = problem.GetOutHeight_(); + const int wo = problem.GetOutWidth_(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto group = problem.GetGroupCount(); size_t gemm_m = static_cast(n) * ho * wo; @@ -478,17 +478,17 @@ bool PerformanceConfigAsmImplicitGemmGTCFwdDlopsNCHWC::IsValid( (problem.IsNCHWc_CHWNc() && tensor_layout == "nchwc_cyxkc"))) return false; - const auto c = problem.GetInChannels(); - const auto k = problem.GetOutChannels(); + const int c = problem.GetInChannels_(); + const int k = problem.GetOutChannels_(); const auto group = problem.GetGroupCount(); - const auto stride_h = problem.GetOutHeight() > 1 ? problem.GetKernelStrideH() : 1; - const auto stride_w = problem.GetOutWidth() > 1 ? problem.GetKernelStrideW() : 1; - const auto dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; - const auto dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; + const auto stride_h = problem.GetOutHeight_() > 1 ? problem.GetKernelStrideH() : 1; + const auto stride_w = problem.GetOutWidth_() > 1 ? problem.GetKernelStrideW() : 1; + const auto dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + const auto dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; const auto pad_h = problem.GetPadH(); const auto pad_w = problem.GetPadW(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); bool unit_conv = (x == 1) && (y == 1) && (stride_h == 1) && (stride_w == 1) && (dilation_h == 1) && (dilation_w == 1) && (pad_h == 0) && (pad_w == 0); @@ -574,13 +574,13 @@ bool ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC::IsApplicable( if(target.Xnack() && *target.Xnack()) return false; // NOLINT (readability-simplify-boolean-expr) - if(0 == igemm_split_batch_size(problem.GetInHeight(), - problem.GetInWidth(), - problem.GetOutHeight(), - problem.GetOutWidth(), - problem.GetBatchSize(), - problem.GetOutChannels(), - problem.GetInChannels(), + if(0 == igemm_split_batch_size(problem.GetInHeight_(), + problem.GetInWidth_(), + problem.GetOutHeight_(), + problem.GetOutWidth_(), + problem.GetBatchSize_(), + problem.GetOutChannels_(), + problem.GetInChannels_(), miopen::GetTypeSize(problem.GetInDataType()))) return false; diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp index 5affd97b1e..601b2e1211 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp @@ -337,15 +337,15 @@ GetImplicitGemmGtcDynamicFwdXdlopsNHWCKernel( const ProblemDescription& problem, const PerformanceConfigAsmImplicitGemmGTCFwdXdlopsNHWC& config) { - const auto n = problem.GetBatchSize(); - const auto k = problem.GetOutChannels(); - const auto ho = problem.GetOutHeight(); - const auto wo = problem.GetOutWidth(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetOutChannels_(); + const int ho = problem.GetOutHeight_(); + const int wo = problem.GetOutWidth_(); const auto group = problem.GetGroupCount(); - const auto hi = problem.GetInHeight(); - const auto wi = problem.GetInWidth(); - const auto c = problem.GetInChannels(); + const int hi = problem.GetInHeight_(); + const int wi = problem.GetInWidth_(); + const int c = problem.GetInChannels_(); auto splits_4G = igemm_split_batch_size( hi, wi, ho, wo, n, k, c, miopen::GetTypeSize(problem.GetInDataType())); @@ -490,19 +490,19 @@ void PerformanceConfigAsmImplicitGemmGTCFwdXdlopsNHWC::HeuristicInit( } #endif - const auto n = problem.GetBatchSize(); - const auto c = problem.GetInChannels(); - const auto k = problem.GetOutChannels(); - const auto ho = problem.GetOutHeight(); - const auto wo = problem.GetOutWidth(); - const auto stride_h = problem.GetInHeight() > 1 ? problem.GetKernelStrideH() : 1; - const auto stride_w = problem.GetInWidth() > 1 ? problem.GetKernelStrideW() : 1; - const auto dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; - const auto dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; + const int n = problem.GetBatchSize_(); + const int c = problem.GetInChannels_(); + const int k = problem.GetOutChannels_(); + const int ho = problem.GetOutHeight_(); + const int wo = problem.GetOutWidth_(); + const auto stride_h = problem.GetInHeight_() > 1 ? problem.GetKernelStrideH() : 1; + const auto stride_w = problem.GetInWidth_() > 1 ? problem.GetKernelStrideW() : 1; + const auto dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + const auto dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; const auto pad_h = problem.GetPadH(); const auto pad_w = problem.GetPadW(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto group = problem.GetGroupCount(); size_t gemm_m = static_cast(n) * ho * wo; @@ -676,23 +676,23 @@ bool PerformanceConfigAsmImplicitGemmGTCFwdXdlopsNHWC::IsValid( if(problem.IsFp16() && gemm_k_global_split != 0 && vector_store != 1) return false; - const auto c = problem.GetInChannels(); - const auto k = problem.GetOutChannels(); + const int c = problem.GetInChannels_(); + const int k = problem.GetOutChannels_(); const auto group = problem.GetGroupCount(); - const auto stride_h = problem.GetInHeight() > 1 ? problem.GetKernelStrideH() : 1; - const auto stride_w = problem.GetInWidth() > 1 ? problem.GetKernelStrideW() : 1; - const auto dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; - const auto dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; + const auto stride_h = problem.GetInHeight_() > 1 ? problem.GetKernelStrideH() : 1; + const auto stride_w = problem.GetInWidth_() > 1 ? problem.GetKernelStrideW() : 1; + const auto dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + const auto dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; const auto pad_h = problem.GetPadH(); const auto pad_w = problem.GetPadW(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); - - const auto n = problem.GetBatchSize(); - const auto ho = problem.GetOutHeight(); - const auto wo = problem.GetOutWidth(); - const auto hi = problem.GetInHeight(); - const auto wi = problem.GetInWidth(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); + + const int n = problem.GetBatchSize_(); + const int ho = problem.GetOutHeight_(); + const int wo = problem.GetOutWidth_(); + const int hi = problem.GetInHeight_(); + const int wi = problem.GetInWidth_(); auto splits_4G = igemm_split_batch_size( hi, wi, ho, wo, n, k, c, miopen::GetTypeSize(problem.GetInDataType())); if(problem.IsFp16() && gemm_k_global_split != 0 && vector_store != 1 && splits_4G > 1) @@ -795,15 +795,15 @@ ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::Search(const ConvolutionContext& ctx size_t ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::GetWorkspaceSize( const ConvolutionContext& ctx, const ProblemDescription& problem) const { - const auto hi = problem.GetInHeight(); - const auto wi = problem.GetInWidth(); - const auto n = problem.GetBatchSize(); - const auto k = problem.GetOutChannels(); - const auto c = problem.GetInChannels(); - const auto ho = problem.GetOutHeight(); - const auto wo = problem.GetOutWidth(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int hi = problem.GetInHeight_(); + const int wi = problem.GetInWidth_(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetOutChannels_(); + const int c = problem.GetInChannels_(); + const int ho = problem.GetOutHeight_(); + const int wo = problem.GetOutWidth_(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto group = problem.GetGroupCount(); const auto is_nchw = problem.IsLayoutDefault(); size_t workspace_size = 0; @@ -854,7 +854,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::IsApplicable( if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_GTC_XDLOPS_NHWC{})) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; #if WORKAROUND_ISSUE_1979 @@ -887,13 +887,13 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::IsApplicable( if(target.Xnack() && *target.Xnack()) return false; // NOLINT (readability-simplify-boolean-expr) - if(0 == igemm_split_batch_size(problem.GetInHeight(), - problem.GetInWidth(), - problem.GetOutHeight(), - problem.GetOutWidth(), - problem.GetBatchSize(), - problem.GetOutChannels(), - problem.GetInChannels(), + if(0 == igemm_split_batch_size(problem.GetInHeight_(), + problem.GetInWidth_(), + problem.GetOutHeight_(), + problem.GetOutWidth_(), + problem.GetBatchSize_(), + problem.GetOutChannels_(), + problem.GetInChannels_(), miopen::GetTypeSize(problem.GetInDataType()))) return false; @@ -945,7 +945,7 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::GetSolution( kernel.l_wk.push_back(1); const auto isGfx90aFp16altSupport = - (ctx.GetStream().GetDeviceName() == "gfx90a") && problem.conv_problem.IsFp16(); + (ctx.GetStream().GetDeviceName() == "gfx90a") && problem.IsFp16(); const auto is_nchw = problem.IsLayoutDefault(); @@ -987,20 +987,20 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::GetSolution( GenerateClangDefsym(opts_1, "igemm_fwd_fp16_alt_impl", 1); result.construction_params[1].comp_options = opts_1.str(); if(miopen::IsLogging(LoggingLevel::Info2)) - msg << ", fp16_alt:" << problem.conv_problem.GetConv().attribute.gfx90aFp16alt.GetFwd(); + msg << ", fp16_alt:" << problem.GetConv().attribute.gfx90aFp16alt.GetFwd(); } if(is_nchw) { - const auto hi = problem.GetInHeight(); - const auto wi = problem.GetInWidth(); - const auto n = problem.GetBatchSize(); - const auto k = problem.GetOutChannels(); - const auto c = problem.GetInChannels(); - const auto ho = problem.GetOutHeight(); - const auto wo = problem.GetOutWidth(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int hi = problem.GetInHeight_(); + const int wi = problem.GetInWidth_(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetOutChannels_(); + const int c = problem.GetInChannels_(); + const int ho = problem.GetOutHeight_(); + const int wo = problem.GetOutWidth_(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto group = problem.GetGroupCount(); TransposeSolutionDefault2Nhwc trans_input(ctx, problem.GetInDataType(), n, c, hi, wi); diff --git a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp index fdf0b64522..35de228c45 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp @@ -334,13 +334,10 @@ GetImplicitGemmGtcDynamicWrwXdlopsNHWCKernel( const ProblemDescription& problem, const PerformanceConfigAsmImplicitGemmGTCWrwXdlopsNHWC& config) { - // const auto n = problem.GetBatchSize(); - const auto k = problem.GetInChannels(); - const auto c = problem.GetOutChannels(); - // const auto ho = problem.GetInHeight(); - // const auto wo = problem.GetInWidth(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int k = problem.GetInChannels_(); + const int c = problem.GetOutChannels_(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto group = problem.GetGroupCount(); // c need to be carefully padded @@ -561,14 +558,14 @@ void PerformanceConfigAsmImplicitGemmGTCWrwXdlopsNHWC::HeuristicInit( } #endif - const auto k = problem.GetInChannels(); - const auto c = problem.GetOutChannels(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int k = problem.GetInChannels_(); + const int c = problem.GetOutChannels_(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto stride_h = problem.GetKernelStrideH(); const auto stride_w = problem.GetKernelStrideW(); - const auto dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; - const auto dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; + const auto dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + const auto dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; const auto pad_h = problem.GetPadH(); const auto pad_w = problem.GetPadW(); const auto group = problem.GetGroupCount(); @@ -765,14 +762,14 @@ bool PerformanceConfigAsmImplicitGemmGTCWrwXdlopsNHWC::IsValid( vector_store != 1) return false; - const auto k = problem.GetInChannels(); - const auto c = problem.GetOutChannels(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int k = problem.GetInChannels_(); + const int c = problem.GetOutChannels_(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto stride_h = problem.GetKernelStrideH(); const auto stride_w = problem.GetKernelStrideW(); - const auto dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; - const auto dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; + const auto dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + const auto dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; const auto pad_h = problem.GetPadH(); const auto pad_w = problem.GetPadW(); const auto precision = @@ -847,7 +844,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable( if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_WRW_GTC_XDLOPS_NHWC{})) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; #if WORKAROUND_ISSUE_1979 @@ -880,13 +877,13 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable( if(target.Xnack() && *target.Xnack()) return false; // NOLINT (readability-simplify-boolean-expr) - if(0 == igemm_split_batch_size(problem.GetOutHeight(), - problem.GetOutWidth(), - problem.GetInHeight(), - problem.GetInWidth(), - problem.GetBatchSize(), - problem.GetInChannels(), - problem.GetOutChannels(), + if(0 == igemm_split_batch_size(problem.GetOutHeight_(), + problem.GetOutWidth_(), + problem.GetInHeight_(), + problem.GetInWidth_(), + problem.GetBatchSize_(), + problem.GetInChannels_(), + problem.GetOutChannels_(), miopen::GetTypeSize(problem.GetInDataType()))) return false; @@ -907,27 +904,27 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable( } static std::vector -ComputeDynamicIGemmWrwKernelArgsNHWC(const conv::ProblemDescription& conv_problem, +ComputeDynamicIGemmWrwKernelArgsNHWC(const conv::ProblemDescription& problem, const int gemm_k_global_splits, const int gemm_k_per_wg, const int splits_4G) { - int hi = conv_problem.GetOutHeight(); - int wi = conv_problem.GetOutWidth(); - int n = conv_problem.GetInBatchSize(); - int k = conv_problem.GetInChannels(); - int c = conv_problem.GetOutChannels(); - int ho = conv_problem.GetInHeight(); - int wo = conv_problem.GetInWidth(); - int stride_h = conv_problem.GetOutHeight() > 1 ? conv_problem.GetKernelStrideH() : 1; - int stride_w = conv_problem.GetOutWidth() > 1 ? conv_problem.GetKernelStrideW() : 1; - int dilation_h = conv_problem.GetWeightsHeight() > 1 ? conv_problem.GetDilationH() : 1; - int dilation_w = conv_problem.GetWeightsWidth() > 1 ? conv_problem.GetDilationW() : 1; - int pad_h = conv_problem.GetPadH(); - int pad_w = conv_problem.GetPadW(); - int y = conv_problem.GetWeightsHeight(); - int x = conv_problem.GetWeightsWidth(); - int group = conv_problem.GetGroupCount(); + int hi = problem.GetOutHeight_(); + int wi = problem.GetOutWidth_(); + int n = problem.GetInBatchSize_(); + int k = problem.GetInChannels_(); + int c = problem.GetOutChannels_(); + int ho = problem.GetInHeight_(); + int wo = problem.GetInWidth_(); + int stride_h = problem.GetOutHeight_() > 1 ? problem.GetKernelStrideH() : 1; + int stride_w = problem.GetOutWidth_() > 1 ? problem.GetKernelStrideW() : 1; + int dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + int dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; + int pad_h = problem.GetPadH(); + int pad_w = problem.GetPadW(); + int y = problem.GetWeightsHeight_(); + int x = problem.GetWeightsWidth_(); + int group = problem.GetGroupCount(); std::vector opArgs; opArgs.emplace_back(0); // placeholder @@ -958,15 +955,15 @@ ComputeDynamicIGemmWrwKernelArgsNHWC(const conv::ProblemDescription& conv_proble size_t ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetWorkspaceSize( const ConvolutionContext& ctx, const ProblemDescription& problem) const { - const auto hi = problem.GetOutHeight(); - const auto wi = problem.GetOutWidth(); - const auto n = problem.GetBatchSize(); - const auto k = problem.GetInChannels(); - const auto c = problem.GetOutChannels(); - const auto ho = problem.GetInHeight(); - const auto wo = problem.GetInWidth(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int hi = problem.GetOutHeight_(); + const int wi = problem.GetOutWidth_(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetInChannels_(); + const int c = problem.GetOutChannels_(); + const int ho = problem.GetInHeight_(); + const int wo = problem.GetInWidth_(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto group = problem.GetGroupCount(); const auto is_nchw = problem.IsLayoutDefault(); @@ -1025,15 +1022,15 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( std::string kernel_name = config.ToKernelName(ctx); - const auto hi = problem.GetOutHeight(); - const auto wi = problem.GetOutWidth(); - const auto n = problem.GetBatchSize(); - const auto k = problem.GetInChannels(); - const auto c = problem.GetOutChannels(); - const auto ho = problem.GetInHeight(); - const auto wo = problem.GetInWidth(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int hi = problem.GetOutHeight_(); + const int wi = problem.GetOutWidth_(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetInChannels_(); + const int c = problem.GetOutChannels_(); + const int ho = problem.GetInHeight_(); + const int wo = problem.GetInWidth_(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto group = problem.GetGroupCount(); auto splits_4G = igemm_split_batch_size( @@ -1052,9 +1049,9 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( gemm_k_global_splits = 1; // compute workload for 1 workgroup and update gemmk splits (remove the ones compute 0 data) - size_t gemmk = integer_divide_ceil(static_cast(problem.GetBatchSize() / splits_4G), + size_t gemmk = integer_divide_ceil(static_cast(problem.GetBatchSize_() / splits_4G), min_n_per_block) * - problem.GetInHeight() * problem.GetInWidth(); + problem.GetInHeight_() * problem.GetInWidth_(); size_t gemmk_per_wg = integer_divide_ceil(gemmk, gemm_k_global_splits); gemmk_per_wg = (gemmk_per_wg + nb_per_block - 1) / nb_per_block * nb_per_block; @@ -1074,10 +1071,9 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( kernel.l_wk.push_back(1); kernel.l_wk.push_back(1); - const auto& conv_problem = problem.conv_problem; - const auto isFp16 = conv_problem.IsFp16(); + const auto isFp16 = problem.IsFp16(); const auto isGfx90aFp16altSupport = (ctx.GetStream().GetDeviceName() == "gfx90a") && isFp16; - const bool need_cast = (conv_problem.IsBfp16() && gemm_k_global_splits >= 1) || + const bool need_cast = (problem.IsBfp16() && gemm_k_global_splits >= 1) || (isFp16 && gemm_k_global_splits >= 1 && (config.tensor_b_thread_lengths[3] == 1 || config.vector_store == 1)); @@ -1121,13 +1117,13 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( GenerateClangDefsym(opts_1, "igemm_wrw_fp16_alt_impl", 1); result.construction_params[1].comp_options = opts_1.str(); if(miopen::IsLogging(LoggingLevel::Info2)) - msg << ", fp16_alt:" << problem.conv_problem.GetConv().attribute.gfx90aFp16alt.GetWrW(); + msg << ", fp16_alt:" << problem.GetConv().attribute.gfx90aFp16alt.GetWrW(); } - const auto lowp_quant = conv_problem.GetConv().lowp_quant; + const auto lowp_quant = problem.GetConv().lowp_quant; auto opArgs = ComputeDynamicIGemmWrwKernelArgsNHWC( - conv_problem, gemm_k_global_splits, gemmk_per_wg, splits_4G); + problem, gemm_k_global_splits, gemmk_per_wg, splits_4G); std::vector> opArgsTrans; size_t trans_input_offset = 0; size_t trans_input_size = 0; @@ -1214,9 +1210,8 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( const int kID_trans_start = isGfx90aFp16altSupport ? 2 : 1; - const TensorDescriptor cast_desc(miopenFloat, - problem.conv_problem.GetWeights().GetLengths(), - problem.conv_problem.GetWeights().GetStrides()); + const TensorDescriptor cast_desc( + miopenFloat, problem.GetWeights().GetLengths(), problem.GetWeights().GetStrides()); auto null_buf = shared{}; if(need_cast) diff --git a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp index 82efae8abc..7846179475 100644 --- a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp @@ -130,10 +130,10 @@ static inline int GetImplicitGemmV4R1DynamicGridSize(const ProblemDescription& p const auto& N1 = config.GemmNRepeat; const auto& N2 = config.GemmNPerThreadSubC; - const auto n = problem.GetBatchSize(); - const auto k = problem.GetOutChannels(); - const auto ho = problem.GetOutHeight(); - const auto wo = problem.GetOutWidth(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetOutChannels_(); + const int ho = problem.GetOutHeight_(); + const int wo = problem.GetOutWidth_(); const auto& b = (static_cast(n) * ho * wo) / (static_cast(N1) * N2); const auto& b_per_block = config.BPerBlock; @@ -343,7 +343,7 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd_1x1::IsApplicable(const ExecutionContext& if(problem.GetGroupCount() != 1) return false; - if((problem.GetWeightsHeight() != 1) || (problem.GetWeightsWidth() != 1)) + if((problem.GetWeightsHeight_() != 1) || (problem.GetWeightsWidth_() != 1)) return false; if(!problem.IsLayoutDefault()) diff --git a/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp b/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp index 58f6bffbb2..3594b26277 100644 --- a/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp +++ b/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp @@ -418,11 +418,11 @@ static inline int if_gemm_k_global_split(const ProblemDescription& problem, const int b) { int gemm_k_global_split = 0; - const auto n = problem.GetBatchSize(); - const auto k = problem.GetInChannels(); - const auto c = problem.GetOutChannels(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetInChannels_(); + const int c = problem.GetOutChannels_(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto& gemm_m = k; const auto gemm_n = c * y * x; @@ -444,27 +444,27 @@ static inline int if_gemm_k_global_split(const ProblemDescription& problem, } inline std::vector -ComputeDynamicIGemmWrwKernelArgs(const conv::ProblemDescription& conv_problem, +ComputeDynamicIGemmWrwKernelArgs(const conv::ProblemDescription& problem, const int log2_gemm_k_global_splits, const int nxb, const int gemm_k_per_block) { - int hi = conv_problem.GetOutHeight(); - int wi = conv_problem.GetOutWidth(); - int n = conv_problem.GetInBatchSize(); - int k = conv_problem.GetInChannels(); - int c = conv_problem.GetOutChannels(); - int ho = conv_problem.GetInHeight(); - int wo = conv_problem.GetInWidth(); - int stride_h = conv_problem.GetOutHeight() > 1 ? conv_problem.GetKernelStrideH() : 1; - int stride_w = conv_problem.GetOutWidth() > 1 ? conv_problem.GetKernelStrideW() : 1; - int dilation_h = conv_problem.GetWeightsHeight() > 1 ? conv_problem.GetDilationH() : 1; - int dilation_w = conv_problem.GetWeightsWidth() > 1 ? conv_problem.GetDilationW() : 1; - int pad_h = conv_problem.GetPadH(); - int pad_w = conv_problem.GetPadW(); - int y = conv_problem.GetWeightsHeight(); - int x = conv_problem.GetWeightsWidth(); - int group = conv_problem.GetGroupCount(); + int hi = problem.GetOutHeight_(); + int wi = problem.GetOutWidth_(); + int n = problem.GetInBatchSize_(); + int k = problem.GetInChannels_(); + int c = problem.GetOutChannels_(); + int ho = problem.GetInHeight_(); + int wo = problem.GetInWidth_(); + int stride_h = problem.GetOutHeight_() > 1 ? problem.GetKernelStrideH() : 1; + int stride_w = problem.GetOutWidth_() > 1 ? problem.GetKernelStrideW() : 1; + int dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + int dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; + int pad_h = problem.GetPadH(); + int pad_w = problem.GetPadW(); + int y = problem.GetWeightsHeight_(); + int x = problem.GetWeightsWidth_(); + int group = problem.GetGroupCount(); int dim_b = (ho * wo + nxb - 1) / nxb * nxb; @@ -534,17 +534,17 @@ static inline std::tuple // gemm_k_split FindImplicitGemmWrwGTCDynamicXdlopsKernel(const ProblemDescription& problem) { - const auto n = problem.GetBatchSize(); - const auto k = problem.GetInChannels(); - const auto c = problem.GetOutChannels(); - const auto ho = problem.GetInHeight(); - const auto wo = problem.GetInWidth(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetInChannels_(); + const int c = problem.GetOutChannels_(); + const int ho = problem.GetInHeight_(); + const int wo = problem.GetInWidth_(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto stride_h = problem.GetKernelStrideH(); const auto stride_w = problem.GetKernelStrideW(); - const auto dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; - const auto dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; + const auto dilation_h = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + const auto dilation_w = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; const auto pad_h = problem.GetPadH(); const auto pad_w = problem.GetPadW(); const auto precision = problem.IsFp16() ? miopenHalf : miopenFloat; @@ -800,10 +800,10 @@ ConvAsmImplicitGemmGTCDynamicWrwXdlops::GetWorkspaceSize(const ExecutionContext& return 0; else { - const auto k = problem.GetInChannels(); - const auto c = problem.GetOutChannels(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int k = problem.GetInChannels_(); + const int c = problem.GetOutChannels_(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto ngroups = problem.GetGroupCount(); return static_cast(ngroups) * (k / ngroups) * (c / ngroups) * y * x * @@ -817,7 +817,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlops::IsApplicable(const ExecutionContext if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_WRW_GTC_XDLOPS{})) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; const auto device_name = ctx.GetStream().GetDeviceName(); @@ -920,13 +920,12 @@ ConvAsmImplicitGemmGTCDynamicWrwXdlops::GetSolution(const ExecutionContext& ctx, result.construction_params.push_back(kernel); - const auto& conv_problem = problem.conv_problem; - const auto& lowp_quant = problem.conv_problem.GetConv().lowp_quant; + const auto& lowp_quant = problem.GetConv().lowp_quant; - auto opArgs = ComputeDynamicIGemmWrwKernelArgs( - conv_problem, log2_gemm_k_global_splits, nxb, gemm_k_per_block); + auto opArgs = + ComputeDynamicIGemmWrwKernelArgs(problem, log2_gemm_k_global_splits, nxb, gemm_k_per_block); - if(conv_problem.IsFp32()) + if(problem.IsFp32()) { result.invoker_factory = [=](const std::vector& kernels) mutable { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) mutable { @@ -957,11 +956,10 @@ ConvAsmImplicitGemmGTCDynamicWrwXdlops::GetSolution(const ExecutionContext& ctx, }; }; } - else if(conv_problem.IsFp16() && log2_gemm_k_global_splits > 0) + else if(problem.IsFp16() && log2_gemm_k_global_splits > 0) { - TensorDescriptor workspaceDesc(miopenFloat, - conv_problem.GetWeights().GetLengths(), - conv_problem.GetWeights().GetStrides()); + TensorDescriptor workspaceDesc( + miopenFloat, problem.GetWeights().GetLengths(), problem.GetWeights().GetStrides()); result.invoker_factory = [=](const std::vector& kernels) mutable { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) mutable { decltype(auto) wrw_invoke_params = diff --git a/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp index 8411992bf2..35020df9b6 100644 --- a/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp @@ -42,13 +42,12 @@ namespace solver { //{ 16, 128, 16, 2, 4, 4, 4, 4, 4, 4, 16, 1, 16, 1, 16, 16}, //{ 8, 32, 4, 2, 2, 2, 2, 4, 4, 2, 4, 2, 8, 1, 4, 16} -static inline int -GetImplicitGemmWrwV4R1DynamicGemmkGroups(const conv::ProblemDescription& conv_problem, - const int& GemmKPerBlock) +static inline int GetImplicitGemmWrwV4R1DynamicGemmkGroups(const conv::ProblemDescription& problem, + const int& GemmKPerBlock) { - int n = conv_problem.GetInBatchSize(); - int ho = conv_problem.GetInHeight(); - int wo = conv_problem.GetInWidth(); + int n = problem.GetInBatchSize_(); + int ho = problem.GetInHeight_(); + int wo = problem.GetInWidth_(); int gemmk = n * ho * wo; int gemmk_groups = 1; int n_per_group; @@ -70,7 +69,7 @@ GetImplicitGemmWrwV4R1DynamicGemmkGroups(const conv::ProblemDescription& conv_pr } static inline float CallImplicitGemmWrwDynamic(const miopen::Handle& handle, - const conv::ProblemDescription& conv_problem, + const conv::ProblemDescription& problem, ConstData_t src, ConstData_t dst, Data_t wei, @@ -81,21 +80,21 @@ static inline float CallImplicitGemmWrwDynamic(const miopen::Handle& handle, auto kernel = kernels[0]; // clang-format off - int hi = conv_problem.GetOutHeight(); - int wi = conv_problem.GetOutWidth(); - int n = conv_problem.GetOutChannels(); - int k = conv_problem.GetInChannels(); - int c = conv_problem.GetInBatchSize(); - int ho = conv_problem.GetWeightsHeight(); - int wo = conv_problem.GetWeightsWidth(); - int dilation_h = conv_problem.GetInHeight() > 1 ? conv_problem.GetKernelStrideH() : 1; - int dilation_w = conv_problem.GetInWidth() > 1 ? conv_problem.GetKernelStrideW() : 1; - int stride_h = conv_problem.GetWeightsHeight() > 1? conv_problem.GetDilationH() : 1; - int stride_w = conv_problem.GetWeightsWidth() > 1? conv_problem.GetDilationW() : 1; - int pad_h = conv_problem.GetPadH(); - int pad_w = conv_problem.GetPadW(); - int y = conv_problem.GetInHeight(); - int x = conv_problem.GetInWidth(); + int hi = problem.GetOutHeight_(); + int wi = problem.GetOutWidth_(); + int n = problem.GetOutChannels_(); + int k = problem.GetInChannels_(); + int c = problem.GetInBatchSize_(); + int ho = problem.GetWeightsHeight_(); + int wo = problem.GetWeightsWidth_(); + int dilation_h = problem.GetInHeight_() > 1 ? problem.GetKernelStrideH() : 1; + int dilation_w = problem.GetInWidth_() > 1 ? problem.GetKernelStrideW() : 1; + int stride_h = problem.GetWeightsHeight_() > 1? problem.GetDilationH() : 1; + int stride_w = problem.GetWeightsWidth_() > 1? problem.GetDilationW() : 1; + int pad_h = problem.GetPadH(); + int pad_w = problem.GetPadW(); + int y = problem.GetInHeight_(); + int x = problem.GetInWidth_(); int gemmk_groups = 0; int GemmKPerBlock; @@ -104,7 +103,7 @@ static inline float CallImplicitGemmWrwDynamic(const miopen::Handle& handle, else GemmKPerBlock = 4; - gemmk_groups = GetImplicitGemmWrwV4R1DynamicGemmkGroups(conv_problem, GemmKPerBlock); + gemmk_groups = GetImplicitGemmWrwV4R1DynamicGemmkGroups(problem, GemmKPerBlock); MIOPEN_LOG_I2(kernel.GetName() << " with groups for reduction: " << (1 << gemmk_groups) << " GemmKPerBlock: " << GemmKPerBlock); @@ -166,13 +165,13 @@ static inline bool FindImplicitGemmWrwV4R1DynamicKernel(const ProblemDescription int& block_size, int& grid_size) { - int n = problem.GetBatchSize(); - int k = problem.GetInChannels(); - int c = problem.GetOutChannels(); - int ho = problem.GetInHeight(); - int wo = problem.GetInWidth(); - int y = problem.GetWeightsHeight(); - int x = problem.GetWeightsWidth(); + int n = problem.GetBatchSize_(); + int k = problem.GetInChannels_(); + int c = problem.GetOutChannels_(); + int ho = problem.GetInHeight_(); + int wo = problem.GetInWidth_(); + int y = problem.GetWeightsHeight_(); + int x = problem.GetWeightsWidth_(); int GemmN = c * y * x; int GemmM = k; int GemmK = n * ho * wo; @@ -198,9 +197,8 @@ static inline bool FindImplicitGemmWrwV4R1DynamicKernel(const ProblemDescription if(GemmM % GemmMPerBlock != 0) return false; - int log2_gemmk_groups = - GetImplicitGemmWrwV4R1DynamicGemmkGroups(problem.conv_problem, GemmKPerBlock); - GemmKGroups = 1 << log2_gemmk_groups; + int log2_gemmk_groups = GetImplicitGemmWrwV4R1DynamicGemmkGroups(problem, GemmKPerBlock); + GemmKGroups = 1 << log2_gemmk_groups; if(GemmK % (GemmKGroups * GemmKPerBlock) != 0) return false; @@ -230,9 +228,8 @@ static inline bool FindImplicitGemmWrwV4R1DynamicKernel(const ProblemDescription if(GemmM % GemmMPerBlock != 0) return false; - int log2_gemmk_groups = - GetImplicitGemmWrwV4R1DynamicGemmkGroups(problem.conv_problem, GemmKPerBlock); - GemmKGroups = 1 << log2_gemmk_groups; + int log2_gemmk_groups = GetImplicitGemmWrwV4R1DynamicGemmkGroups(problem, GemmKPerBlock); + GemmKGroups = 1 << log2_gemmk_groups; if(GemmK % (GemmKGroups * GemmKPerBlock) != 0) return false; @@ -251,10 +248,10 @@ static inline bool FindImplicitGemmWrwV4R1DynamicKernel(const ProblemDescription size_t ConvAsmImplicitGemmV4R1DynamicWrw::GetWorkspaceSize(const ExecutionContext&, const ProblemDescription& problem) const { - int k = problem.GetInChannels(); - int c = problem.GetOutChannels(); - int y = problem.GetWeightsHeight(); - int x = problem.GetWeightsWidth(); + int k = problem.GetInChannels_(); + int c = problem.GetOutChannels_(); + int y = problem.GetWeightsHeight_(); + int x = problem.GetWeightsWidth_(); int ele_size = 0; int gemmk_groups = 0; int extra_groups = 0; @@ -271,7 +268,7 @@ size_t ConvAsmImplicitGemmV4R1DynamicWrw::GetWorkspaceSize(const ExecutionContex else ele_size = 2; - gemmk_groups = GetImplicitGemmWrwV4R1DynamicGemmkGroups(problem.conv_problem, GemmKPerBlock); + gemmk_groups = GetImplicitGemmWrwV4R1DynamicGemmkGroups(problem, GemmKPerBlock); if(gemmk_groups == 0) extra_groups = 0; @@ -282,17 +279,17 @@ size_t ConvAsmImplicitGemmV4R1DynamicWrw::GetWorkspaceSize(const ExecutionContex static int GetGemmkGroups(const ProblemDescription& problem) { - const auto k = problem.GetInChannels(); - const auto c = problem.GetOutChannels(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int k = problem.GetInChannels_(); + const int c = problem.GetOutChannels_(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto GemmN = c * y * x; int GemmKPerBlock = 4; if((k % 128 == 0) && (GemmN % 128 == 0)) GemmKPerBlock = 16; - return GetImplicitGemmWrwV4R1DynamicGemmkGroups(problem.conv_problem, GemmKPerBlock); + return GetImplicitGemmWrwV4R1DynamicGemmkGroups(problem, GemmKPerBlock); } bool ConvAsmImplicitGemmV4R1DynamicWrw::IsApplicable(const ExecutionContext& ctx, @@ -391,8 +388,8 @@ ConvSolution ConvAsmImplicitGemmV4R1DynamicWrw::GetSolution(const ExecutionConte kernel_reduction.kernel_name = "wrw_reduction_hip"; kernel_reduction.g_wk.clear(); int block_size_reduction = 256; - int grid_size_redcution = problem.GetOutChannels() * problem.GetInChannels() * - problem.GetWeightsHeight() * problem.GetWeightsWidth() / + int grid_size_redcution = problem.GetOutChannels_() * problem.GetInChannels_() * + problem.GetWeightsHeight_() * problem.GetWeightsWidth_() / (reduction_per_thread * block_size_reduction); kernel_reduction.g_wk.push_back(static_cast(grid_size_redcution) * block_size_reduction); @@ -409,9 +406,7 @@ ConvSolution ConvAsmImplicitGemmV4R1DynamicWrw::GetSolution(const ExecutionConte result.construction_params.push_back(kernel_reduction); } - const auto& conv_problem = problem.conv_problem; - - result.invoker_factory = [conv_problem](const std::vector& kernels) { + result.invoker_factory = [problem](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { decltype(auto) data_ctx = primitive_parameters.CastTo(); const auto& tensors = data_ctx.tensors; @@ -424,7 +419,7 @@ ConvSolution ConvAsmImplicitGemmV4R1DynamicWrw::GetSolution(const ExecutionConte [&](const Kernel& k_wrw) { return handle.Run(k_wrw); }); float elapsed = 0; elapsed = CallImplicitGemmWrwDynamic( - handle, conv_problem, tensors.x, tensors.dy, tensors.dw, workSpace, ks); + handle, problem, tensors.x, tensors.dy, tensors.dw, workSpace, ks); if(handle.IsProfilingEnabled()) { handle.ResetKernelTime(); diff --git a/src/solver/conv_bin_wino3x3U.cpp b/src/solver/conv_bin_wino3x3U.cpp index 38be7c572d..e42aab012c 100644 --- a/src/solver/conv_bin_wino3x3U.cpp +++ b/src/solver/conv_bin_wino3x3U.cpp @@ -72,24 +72,24 @@ bool ConvBinWinograd3x3U::IsApplicable(const ExecutionContext& ctx, // clang-format off return problem.GetPadW() == 1 && problem.GetPadH() == 1 - && problem.GetWeightsWidth() == 3 - && problem.GetWeightsHeight() == 3 + && problem.GetWeightsWidth_() == 3 + && problem.GetWeightsHeight_() == 3 && problem.GetKernelStrideW() == 1 && problem.GetKernelStrideH() == 1 && problem.GetDilationW() == 1 && problem.GetDilationH() == 1 - && problem.GetBatchSize() < std::pow(2, 16) - && problem.GetInChannels() < std::pow(2, 16) - && problem.GetOutChannels() < std::pow(2, 16) - && problem.GetInHeight() < std::pow(2, 16) - && problem.GetInWidth() < std::pow(2, 16) + && problem.GetBatchSize_() < std::pow(2, 16) + && problem.GetInChannels_() < std::pow(2, 16) + && problem.GetOutChannels_() < std::pow(2, 16) + && problem.GetInHeight_() < std::pow(2, 16) + && problem.GetInWidth_() < std::pow(2, 16) && grid_workgroup_count_x < std::pow(2, 16) - && (problem.GetInChannels() * problem.GetInHeight() * problem.GetInWidth()) <= std::pow(2, 28) - && (problem.GetOutChannels() * problem.GetInHeight() * problem.GetInWidth()) <= std::pow(2, 28) - && (problem.GetInChannels() * problem.GetWeightsWidth() * problem.GetWeightsHeight()) <= std::pow(2, 28) - && (problem.GetOutChannels() * problem.GetWeightsWidth() * problem.GetWeightsHeight()) <= std::pow(2, 28) - && problem.GetInChannels() % 2 == 0 - && problem.GetInChannels() >= (device_is_gfx8 ? 16 : 18) + && (problem.GetInChannels_() * problem.GetInHeight_() * problem.GetInWidth_()) <= std::pow(2, 28) + && (problem.GetOutChannels_() * problem.GetInHeight_() * problem.GetInWidth_()) <= std::pow(2, 28) + && (problem.GetInChannels_() * problem.GetWeightsWidth_() * problem.GetWeightsHeight_()) <= std::pow(2, 28) + && (problem.GetOutChannels_() * problem.GetWeightsWidth_() * problem.GetWeightsHeight_()) <= std::pow(2, 28) + && problem.GetInChannels_() % 2 == 0 + && problem.GetInChannels_() >= (device_is_gfx8 ? 16 : 18) && problem.IsFp32() && problem.GetGroupCount() == 1 && problem.GetInLayout() == "NCHW"; diff --git a/src/solver/conv_bin_winoRxS.cpp b/src/solver/conv_bin_winoRxS.cpp index 1d2660cd0e..8b42bf5899 100644 --- a/src/solver/conv_bin_winoRxS.cpp +++ b/src/solver/conv_bin_winoRxS.cpp @@ -282,17 +282,17 @@ bool ConvBinWinogradRxS::IsApplicable(const ExecutionContext& ctx, { return IsShaderContraintsMet(ctx, problem, - problem.GetInHeight(), - problem.GetInWidth(), + problem.GetInHeight_(), + problem.GetInWidth_(), problem.GetDilationH(), problem.GetDilationW(), - problem.GetBatchSize(), // N - problem.GetInChannels(), // K - problem.GetOutHeight(), - problem.GetOutWidth(), - problem.GetWeightsHeight(), - problem.GetWeightsWidth(), - problem.GetOutChannels(), // C + problem.GetBatchSize_(), // N + problem.GetInChannels_(), // K + problem.GetOutHeight_(), + problem.GetOutWidth_(), + problem.GetWeightsHeight_(), + problem.GetWeightsWidth_(), + problem.GetOutChannels_(), // C fp16, 2); } @@ -300,17 +300,17 @@ bool ConvBinWinogradRxS::IsApplicable(const ExecutionContext& ctx, { return IsShaderContraintsMet(ctx, problem, - problem.GetWeightsHeight(), // RxS - problem.GetWeightsWidth(), + problem.GetWeightsHeight_(), // RxS + problem.GetWeightsWidth_(), problem.GetKernelStrideH(), problem.GetKernelStrideW(), - problem.GetInChannels(), // C - problem.GetOutChannels(), // K - problem.GetInHeight(), // HxW - problem.GetInWidth(), - problem.GetOutHeight(), // OHxOW - problem.GetOutWidth(), - problem.GetBatchSize(), // N + problem.GetInChannels_(), // C + problem.GetOutChannels_(), // K + problem.GetInHeight_(), // HxW + problem.GetInWidth_(), + problem.GetOutHeight_(), // OHxOW + problem.GetOutWidth_(), + problem.GetBatchSize_(), // N fp16, 3); } diff --git a/src/solver/conv_bin_winoRxS_fused.cpp b/src/solver/conv_bin_winoRxS_fused.cpp index 03982a841c..f184b0e291 100644 --- a/src/solver/conv_bin_winoRxS_fused.cpp +++ b/src/solver/conv_bin_winoRxS_fused.cpp @@ -72,18 +72,18 @@ bool ConvBinWinogradRxSFused::IsApplicable(const FusionContext& context, if(name != "gfx803") return false; - const auto W = conv_problem.conv_problem.GetInWidth(); - const auto H = conv_problem.conv_problem.GetInHeight(); - const auto C = conv_problem.conv_problem.GetInChannels(); - const auto N = conv_problem.conv_problem.GetInBatchSize(); - const auto K = conv_problem.conv_problem.GetOutChannels(); - const auto y = conv_problem.conv_problem.GetWeightsHeight(); - const auto x = conv_problem.conv_problem.GetWeightsWidth(); - const auto OH = conv_problem.conv_problem.GetOutHeight(); - const auto OW = conv_problem.conv_problem.GetOutWidth(); - const auto pad_h = conv_problem.conv_problem.GetPadH(); - const auto pad_w = conv_problem.conv_problem.GetPadW(); - const auto group_count = conv_problem.conv_problem.GetGroupCount(); + const auto W = conv_problem.GetInWidth_(); + const auto H = conv_problem.GetInHeight_(); + const auto C = conv_problem.GetInChannels_(); + const auto N = conv_problem.GetInBatchSize_(); + const auto K = conv_problem.GetOutChannels_(); + const auto y = conv_problem.GetWeightsHeight_(); + const auto x = conv_problem.GetWeightsWidth_(); + const auto OH = conv_problem.GetOutHeight_(); + const auto OW = conv_problem.GetOutWidth_(); + const auto pad_h = conv_problem.GetPadH(); + const auto pad_w = conv_problem.GetPadW(); + const auto group_count = conv_problem.GetGroupCount(); size_t padded_y = 0; size_t padded_x = 0; @@ -120,10 +120,10 @@ bool ConvBinWinogradRxSFused::IsApplicable(const FusionContext& context, return conv_problem.GetKernelStrideH() == conv_problem.GetKernelStrideW() && conv_problem.GetDilationH() == 1 && conv_problem.GetDilationW() == 1 - && (C * x * y) <= std::pow(2, 28) - && (K * x * y) <= std::pow(2, 28) - && (K * OH * OW) <= std::pow(2, 28) - && (C * H * W) <= std::pow(2, 28) + && (static_cast(C) * x * y) <= std::pow(2, 28) + && (static_cast(K) * x * y) <= std::pow(2, 28) + && (static_cast(K) * OH * OW) <= std::pow(2, 28) + && (static_cast(C) * H * W) <= std::pow(2, 28) && y <= std::pow(2, 16) && x <= std::pow(2, 16) && pad_h <= std::pow(2, 16) @@ -160,8 +160,8 @@ ConvSolution ConvBinWinogradRxSFused::GetSolution(const FusionContext& context, {"ROCM_METADATA_VERSION", conv_ctx.rmv.UseV3() ? 5 : 4}, }; kernel.comp_options = options.GenerateFor(kbp::GcnAsm{}); - const auto x = conv_problem.conv_problem.GetWeightsWidth(); - const auto y = conv_problem.conv_problem.GetWeightsHeight(); + const auto x = conv_problem.GetWeightsWidth_(); + const auto y = conv_problem.GetWeightsHeight_(); kernel.kernel_name = "miopenSp3AsmConvRxSU_CBA"; if(conv_problem.GetKernelStrideH() == 1) diff --git a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp index 96504bef22..cdb4225b88 100644 --- a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp +++ b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp @@ -278,10 +278,10 @@ void PerformanceConfigConvCKIgemmFwdBiasActivFused::HeuristicInit( #if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL std::ignore = fdesc_problem; #else - const auto& conv_prob = fdesc_problem.GetConvProblem(0, conv::Direction::Forward).conv_problem; - switch(conv_prob.GetInDataType()) + const auto conv_problem = fdesc_problem.GetConvProblem(0, conv::Direction::Forward); + switch(conv_problem.GetInDataType()) { - case miopenHalf: Init(conv_prob); break; + case miopenHalf: Init(conv_problem); break; case miopenInt8: case miopenFloat: case miopenInt32: @@ -331,10 +331,10 @@ bool PerformanceConfigConvCKIgemmFwdBiasActivFused::IsValid( return false; #else // Extract convolution problem from the fusion context. - const auto& problem = fdesc_problem.GetConvProblem(0, conv::Direction::Forward); - switch(problem.conv_problem.GetInDataType()) + const auto conv_problem = fdesc_problem.GetConvProblem(0, conv::Direction::Forward); + switch(conv_problem.GetInDataType()) { - case miopenHalf: return CheckIsSupportCKArgs(problem); + case miopenHalf: return CheckIsSupportCKArgs(conv_problem); case miopenInt8: case miopenFloat: case miopenInt32: @@ -405,23 +405,23 @@ bool ConvCKIgemmFwdBiasActivFused::IsApplicable(const FusionContext& ctx, const auto& activ_op = dynamic_cast(*desc.op_map[2]); if(activ_op.activMode != miopenActivationRELU) return false; - const auto& problem = fdesc_problem.GetConvProblem(0, conv::Direction::Forward); - if(problem.conv_problem.GetConv().attribute.deterministic) + const auto conv_problem = fdesc_problem.GetConvProblem(0, conv::Direction::Forward); + if(conv_problem.GetConv().attribute.deterministic) return false; - if(problem.conv_problem.GetInDataType() != problem.conv_problem.GetWeightsDataType() || - problem.conv_problem.GetInDataType() != problem.conv_problem.GetOutDataType()) + if(conv_problem.GetInDataType() != conv_problem.GetWeightsDataType() || + conv_problem.GetInDataType() != conv_problem.GetOutDataType()) return false; - if(!problem.Is2d()) + if(!conv_problem.Is2d()) return false; const std::string arch = ctx.GetStream().GetDeviceName(); if(arch != "gfx908" && arch != "gfx90a") return false; - if(!problem.IsLayoutNHWC()) + if(!conv_problem.IsLayoutNHWC()) return false; - switch(problem.conv_problem.GetInDataType()) + switch(conv_problem.GetInDataType()) { - case miopenHalf: return CheckCKApplicability(problem); + case miopenHalf: return CheckCKApplicability(conv_problem); case miopenInt8: case miopenFloat: case miopenInt32: @@ -444,15 +444,15 @@ ConvSolution ConvCKIgemmFwdBiasActivFused::GetSolution( std::ignore = config; return {}; #else - const auto& problem = fdesc_problem.GetConvProblem(0, conv::Direction::Forward); + const auto conv_problem = fdesc_problem.GetConvProblem(0, conv::Direction::Forward); ConvSolution result; result.invoker_factory = [=](const std::vector& kernels) { std::ignore = kernels; return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { - switch(problem.conv_problem.GetInDataType()) + switch(conv_problem.GetInDataType()) { case miopenHalf: - RunCKSolution(handle, primitive_parameters, problem, config); + RunCKSolution(handle, primitive_parameters, conv_problem, config); break; case miopenInt8: case miopenFloat: diff --git a/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp b/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp index 83ed3efcc5..0c743ee9d7 100644 --- a/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp +++ b/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp @@ -108,7 +108,7 @@ bool ConvCkIgemmFwdV6r1DlopsNchw::IsApplicable(const ConvolutionContext& ctx, if(problem.GetGroupCount() != 1) return false; if(ctx.GetStream().GetTargetProperties().Name() == "gfx90a" && - problem.conv_problem.IsGfx90aFp16altRequired()) + problem.IsGfx90aFp16altRequired()) return false; if(!IsIndexRangeLargeEnough(problem)) return false; diff --git a/src/solver/conv_direct_naive_conv_bwd.cpp b/src/solver/conv_direct_naive_conv_bwd.cpp index a58e259386..077fe550bc 100644 --- a/src/solver/conv_direct_naive_conv_bwd.cpp +++ b/src/solver/conv_direct_naive_conv_bwd.cpp @@ -61,27 +61,27 @@ ConvSolution ConvDirectNaiveConvBwd::GetSolution(const ConvolutionContext& ctx, { ConvSolution result; - int di = problem.GetOutDepth(); - int hi = problem.GetOutHeight(); - int wi = problem.GetOutWidth(); - int n = problem.GetBatchSize(); - int k = problem.GetInChannels(); - int c = problem.GetOutChannels(); - int do_ = problem.GetInDepth(); - int ho = problem.GetInHeight(); - int wo = problem.GetInWidth(); - int sz = problem.GetInDepth() > 1 ? problem.GetKernelStrideD() : 1; - int sy = problem.GetInHeight() > 1 ? problem.GetKernelStrideH() : 1; - int sx = problem.GetInWidth() > 1 ? problem.GetKernelStrideW() : 1; - int dz = problem.GetWeightsDepth() > 1 ? problem.GetDilationD() : 1; - int dy = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; - int dx = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; + int di = problem.GetOutDepth_(); + int hi = problem.GetOutHeight_(); + int wi = problem.GetOutWidth_(); + int n = problem.GetBatchSize_(); + int k = problem.GetInChannels_(); + int c = problem.GetOutChannels_(); + int do_ = problem.GetInDepth_(); + int ho = problem.GetInHeight_(); + int wo = problem.GetInWidth_(); + int sz = problem.GetInDepth_() > 1 ? problem.GetKernelStrideD() : 1; + int sy = problem.GetInHeight_() > 1 ? problem.GetKernelStrideH() : 1; + int sx = problem.GetInWidth_() > 1 ? problem.GetKernelStrideW() : 1; + int dz = problem.GetWeightsDepth_() > 1 ? problem.GetDilationD() : 1; + int dy = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + int dx = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; int pz = problem.GetPadD(); int py = problem.GetPadH(); int px = problem.GetPadW(); - int fz = problem.GetWeightsDepth(); - int fy = problem.GetWeightsHeight(); - int fx = problem.GetWeightsWidth(); + int fz = problem.GetWeightsDepth_(); + int fy = problem.GetWeightsHeight_(); + int fx = problem.GetWeightsWidth_(); int group = problem.GetGroupCount(); int c_per_group = c / group; int k_per_group = k / group; diff --git a/src/solver/conv_direct_naive_conv_fwd.cpp b/src/solver/conv_direct_naive_conv_fwd.cpp index 74626ee3db..b3b2da870c 100644 --- a/src/solver/conv_direct_naive_conv_fwd.cpp +++ b/src/solver/conv_direct_naive_conv_fwd.cpp @@ -61,15 +61,15 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ConvolutionContext& ctx, { ConvSolution result; - int di = problem.GetInDepth(); - int hi = problem.GetInHeight(); - int wi = problem.GetInWidth(); - int n = problem.GetBatchSize(); - int k = problem.GetOutChannels(); - int c = problem.GetInChannels(); - int do_ = problem.GetOutDepth(); - int ho = problem.GetOutHeight(); - int wo = problem.GetOutWidth(); + int di = problem.GetInDepth_(); + int hi = problem.GetInHeight_(); + int wi = problem.GetInWidth_(); + int n = problem.GetBatchSize_(); + int k = problem.GetOutChannels_(); + int c = problem.GetInChannels_(); + int do_ = problem.GetOutDepth_(); + int ho = problem.GetOutHeight_(); + int wo = problem.GetOutWidth_(); int sz = problem.GetKernelStrideD(); int sy = problem.GetKernelStrideH(); int sx = problem.GetKernelStrideW(); @@ -79,9 +79,9 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ConvolutionContext& ctx, int pz = problem.GetPadD(); int py = problem.GetPadH(); int px = problem.GetPadW(); - int fz = problem.GetWeightsDepth(); - int fy = problem.GetWeightsHeight(); - int fx = problem.GetWeightsWidth(); + int fz = problem.GetWeightsDepth_(); + int fy = problem.GetWeightsHeight_(); + int fx = problem.GetWeightsWidth_(); int group = problem.GetGroupCount(); int c_per_group = c / group; int k_per_group = k / group; diff --git a/src/solver/conv_direct_naive_conv_wrw.cpp b/src/solver/conv_direct_naive_conv_wrw.cpp index 911c02cda5..f25d3a3baa 100644 --- a/src/solver/conv_direct_naive_conv_wrw.cpp +++ b/src/solver/conv_direct_naive_conv_wrw.cpp @@ -61,27 +61,27 @@ ConvSolution ConvDirectNaiveConvWrw::GetSolution(const ConvolutionContext& ctx, { ConvSolution result; - int di = problem.GetOutDepth(); - int hi = problem.GetOutHeight(); - int wi = problem.GetOutWidth(); - int n = problem.GetBatchSize(); - int k = problem.GetInChannels(); - int c = problem.GetOutChannels(); - int do_ = problem.GetInDepth(); - int ho = problem.GetInHeight(); - int wo = problem.GetInWidth(); - int sz = problem.GetInDepth() > 1 ? problem.GetKernelStrideD() : 1; - int sy = problem.GetInHeight() > 1 ? problem.GetKernelStrideH() : 1; - int sx = problem.GetInWidth() > 1 ? problem.GetKernelStrideW() : 1; - int dz = problem.GetWeightsDepth() > 1 ? problem.GetDilationD() : 1; - int dy = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; - int dx = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; + int di = problem.GetOutDepth_(); + int hi = problem.GetOutHeight_(); + int wi = problem.GetOutWidth_(); + int n = problem.GetBatchSize_(); + int k = problem.GetInChannels_(); + int c = problem.GetOutChannels_(); + int do_ = problem.GetInDepth_(); + int ho = problem.GetInHeight_(); + int wo = problem.GetInWidth_(); + int sz = problem.GetInDepth_() > 1 ? problem.GetKernelStrideD() : 1; + int sy = problem.GetInHeight_() > 1 ? problem.GetKernelStrideH() : 1; + int sx = problem.GetInWidth_() > 1 ? problem.GetKernelStrideW() : 1; + int dz = problem.GetWeightsDepth_() > 1 ? problem.GetDilationD() : 1; + int dy = problem.GetWeightsHeight_() > 1 ? problem.GetDilationH() : 1; + int dx = problem.GetWeightsWidth_() > 1 ? problem.GetDilationW() : 1; int pz = problem.GetPadD(); int py = problem.GetPadH(); int px = problem.GetPadW(); - int fz = problem.GetWeightsDepth(); - int fy = problem.GetWeightsHeight(); - int fx = problem.GetWeightsWidth(); + int fz = problem.GetWeightsDepth_(); + int fy = problem.GetWeightsHeight_(); + int fx = problem.GetWeightsWidth_(); int group = problem.GetGroupCount(); int c_per_group = c / group; int k_per_group = k / group; diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp index b4727c1081..5127f326e9 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp @@ -297,7 +297,7 @@ MakeCK3DGroupFwdInvokerFactory(const miopen::ProblemDescription& problem, const PerformanceConfigHipImplicitGemm3DGroupFwdXdlops& config) { auto args = CKArgs{problem}; - miopenDataType_t data_type = problem.conv_problem.GetInDataType(); + miopenDataType_t data_type = problem.GetInDataType(); auto kernel_id = config.kernel_id; return [args, data_type, kernel_id](const std::vector& kernels) { @@ -335,7 +335,7 @@ void PerformanceConfigHipImplicitGemm3DGroupFwdXdlops::HeuristicInit( #if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL std::ignore = problem; #else - switch(problem.conv_problem.GetInDataType()) + switch(problem.GetInDataType()) { case miopenHalf: Init(problem); break; case miopenFloat: Init(problem); break; @@ -379,7 +379,7 @@ bool PerformanceConfigHipImplicitGemm3DGroupFwdXdlops::IsValid( std::ignore = problem; return false; #else - switch(problem.conv_problem.GetInDataType()) + switch(problem.GetInDataType()) { case miopenHalf: return CheckIsSupportCKArgs(problem); case miopenFloat: return CheckIsSupportCKArgs(problem); @@ -434,11 +434,11 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable(const ConvolutionContext& #else if(miopen::IsDisabled(MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS{})) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; - if(problem.conv_problem.GetInDataType() != problem.conv_problem.GetWeightsDataType() || - problem.conv_problem.GetWeightsDataType() != problem.conv_problem.GetOutDataType() || - problem.conv_problem.GetInDataType() != problem.conv_problem.GetOutDataType()) + if(problem.GetInDataType() != problem.GetWeightsDataType() || + problem.GetWeightsDataType() != problem.GetOutDataType() || + problem.GetInDataType() != problem.GetOutDataType()) return false; if(!problem.direction.IsForward()) return false; @@ -449,7 +449,7 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable(const ConvolutionContext& const std::string& arch = ctx.GetStream().GetDeviceName(); if(!(arch == "gfx908" || arch == "gfx90a")) return false; - switch(problem.conv_problem.GetInDataType()) + switch(problem.GetInDataType()) { case miopenHalf: return CheckCKApplicability(problem); case miopenFloat: return CheckCKApplicability(problem); diff --git a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp index 469b0f39d1..9566d17331 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp @@ -256,7 +256,7 @@ void PerformanceConfigHipImplicitGemmBwdXdlops::HeuristicInit(const ProblemDescr #if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL std::ignore = problem; #else - switch(problem.conv_problem.GetInDataType()) + switch(problem.GetInDataType()) { case miopenHalf: Init(problem); break; case miopenFloat: Init(problem); break; @@ -298,7 +298,7 @@ bool PerformanceConfigHipImplicitGemmBwdXdlops::IsValid(const ProblemDescription std::ignore = problem; return false; #else - switch(problem.conv_problem.GetInDataType()) + switch(problem.GetInDataType()) { case miopenHalf: return CheckIsSupportCKArgs(problem); case miopenFloat: return CheckIsSupportCKArgs(problem); @@ -353,11 +353,11 @@ bool ConvHipImplicitGemmBwdXdlops::IsApplicable(const ConvolutionContext& ctx, #else if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_XDLOPS{})) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; - if(problem.conv_problem.GetInDataType() != problem.conv_problem.GetWeightsDataType() || - problem.conv_problem.GetWeightsDataType() != problem.conv_problem.GetOutDataType() || - problem.conv_problem.GetInDataType() != problem.conv_problem.GetOutDataType()) + if(problem.GetInDataType() != problem.GetWeightsDataType() || + problem.GetWeightsDataType() != problem.GetOutDataType() || + problem.GetInDataType() != problem.GetOutDataType()) return false; if(!problem.direction.IsBackwardData()) return false; @@ -370,13 +370,13 @@ bool ConvHipImplicitGemmBwdXdlops::IsApplicable(const ConvolutionContext& ctx, if(!IsComposableKernelSupportedHardware(ctx)) return false; const std::string& arch = ctx.GetStream().GetDeviceName(); - if(arch == "gfx90a" && problem.conv_problem.IsGfx90aFp16altRequired()) + if(arch == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; if(!IsIndexRangeLargeEnough(problem)) return false; if(problem.GetGroupCount() > 1) return false; - switch(problem.conv_problem.GetInDataType()) + switch(problem.GetInDataType()) { case miopenHalf: return CheckCKApplicability(problem); case miopenFloat: return CheckCKApplicability(problem); @@ -405,7 +405,7 @@ ConvSolution ConvHipImplicitGemmBwdXdlops::GetSolution( result.invoker_factory = [=](const std::vector& kernels) { std::ignore = kernels; return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { - switch(problem.conv_problem.GetInDataType()) + switch(problem.GetInDataType()) { case miopenHalf: RunCKSolution(handle, primitive_parameters, problem, config); diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp index c6aa6e0451..ec3ea2fc65 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp @@ -634,7 +634,7 @@ bool ConvHipImplicitGemmBwdDataV1R1::IsApplicable(const ConvolutionContext& ctx, return false; if(ThisSolverIsDeprecatedStatic::IsDisabled(ctx)) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!ctx.use_hip_kernels) return false; diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp index ac339d28f4..1075556d39 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp @@ -763,7 +763,7 @@ bool ConvHipImplicitGemmBwdDataV1R1Xdlops::IsApplicable(const ConvolutionContext if(ThisSolverIsDeprecatedStatic::IsDisabled(ctx)) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!IsComposableKernelSupportedHardware(ctx)) @@ -784,8 +784,7 @@ bool ConvHipImplicitGemmBwdDataV1R1Xdlops::IsApplicable(const ConvolutionContext if(!problem.Is2d()) return false; - if(ctx.GetStream().GetDeviceName() == "gfx90a" && - problem.conv_problem.IsGfx90aFp16altRequired()) + if(ctx.GetStream().GetDeviceName() == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; if(!IsIndexRangeLargeEnough(problem)) diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp index 621d752a78..79250a6b61 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp @@ -736,7 +736,7 @@ bool ConvHipImplicitGemmBwdDataV4R1::IsApplicable(const ConvolutionContext& ctx, if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1{})) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!IsComposableKernelSupportedHardware(ctx)) diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp index 993e3fd7b9..8a147585ab 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp @@ -828,7 +828,7 @@ bool ConvHipImplicitGemmBwdDataV4R1Xdlops::IsApplicable(const ConvolutionContext return false; if(ThisSolverIsDeprecatedStatic::IsDisabled(ctx)) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; @@ -846,8 +846,7 @@ bool ConvHipImplicitGemmBwdDataV4R1Xdlops::IsApplicable(const ConvolutionContext return false; if(!problem.IsLayoutDefault()) return false; - if(ctx.GetStream().GetDeviceName() == "gfx90a" && - problem.conv_problem.IsGfx90aFp16altRequired()) + if(ctx.GetStream().GetDeviceName() == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; bool is_applicable = true; diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp index 2b85f86147..99edfd139d 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp @@ -49,7 +49,7 @@ bool ConvHipImplicitGemmV4R1Fwd::IsApplicable(const ConvolutionContext& ctx, return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!problem.direction.IsForward()) return false; @@ -63,17 +63,16 @@ bool ConvHipImplicitGemmV4R1Fwd::IsApplicable(const ConvolutionContext& ctx, return false; if(!problem.IsLayoutDefault()) return false; - if(ctx.GetStream().GetDeviceName() == "gfx90a" && - problem.conv_problem.IsGfx90aFp16altRequired()) + if(ctx.GetStream().GetDeviceName() == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; - std::size_t n = problem.GetBatchSize(); - std::size_t k = problem.GetOutChannels() / problem.GetGroupCount(); - std::size_t c = problem.GetInChannels() / problem.GetGroupCount(); - std::size_t y = problem.GetWeightsHeight(); - std::size_t x = problem.GetWeightsWidth(); - std::size_t ho = problem.GetOutHeight(); - std::size_t wo = problem.GetOutWidth(); + std::size_t n = problem.GetBatchSize_(); + std::size_t k = problem.GetOutChannels_() / problem.GetGroupCount(); + std::size_t c = problem.GetInChannels_() / problem.GetGroupCount(); + std::size_t y = problem.GetWeightsHeight_(); + std::size_t x = problem.GetWeightsWidth_(); + std::size_t ho = problem.GetOutHeight_(); + std::size_t wo = problem.GetOutWidth_(); std::size_t eMultiple = (problem.IsFp16() || problem.IsBfp16()) ? 16 : 8; // batch is divided by epack to pack 2/4 fp16/bfp16 @@ -103,20 +102,19 @@ bool ConvHipImplicitGemmV4R1WrW::IsApplicable(const ConvolutionContext& ctx, return false; if(!problem.IsLayoutDefault()) return false; - if(ctx.GetStream().GetDeviceName() == "gfx90a" && - problem.conv_problem.IsGfx90aFp16altRequired()) + if(ctx.GetStream().GetDeviceName() == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; - // retrieve dimension from ConvolutionContext - // remember: ConvolutionContext has swapped some dimensions for you! + // retrieve dimension from ProblemDescription + // remember: ProblemDescription has swapped some dimensions for you! // undo the swap to avoid confusion - const auto n = problem.GetBatchSize(); - const auto k = problem.GetInChannels() / problem.GetGroupCount(); // unswap - const auto c = problem.GetOutChannels() / problem.GetGroupCount(); // unswap - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); - const auto ho = problem.GetInHeight(); // unswap - const auto wo = problem.GetInWidth(); // unswap + const int n = problem.GetBatchSize_(); + const int k = problem.GetInChannels_() / problem.GetGroupCount(); // unswap + const int c = problem.GetOutChannels_() / problem.GetGroupCount(); // unswap + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); + const int ho = problem.GetInHeight_(); // unswap + const int wo = problem.GetInWidth_(); // unswap // equivalent dimension for bwd-wrw std::size_t n_eqv = c; @@ -196,15 +194,15 @@ ConvHipImplicitGemmV4R1Fwd::GetSolution(const ConvolutionContext& ctx, const auto& N2 = config.GemmNPerThreadSubC; // retrieve dimension from ProblemDescription - const auto n = problem.GetBatchSize(); - const auto k = problem.GetOutChannels(); - const auto c = problem.GetInChannels(); - const auto hi = problem.GetInHeight(); - const auto wi = problem.GetInWidth(); - const auto ho = problem.GetOutHeight(); - const auto wo = problem.GetOutWidth(); - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetOutChannels_(); + const int c = problem.GetInChannels_(); + const int hi = problem.GetInHeight_(); + const int wi = problem.GetInWidth_(); + const int ho = problem.GetOutHeight_(); + const int wo = problem.GetOutWidth_(); + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto conv_stride_h = problem.GetKernelStrideH(); const auto conv_stride_w = problem.GetKernelStrideW(); const auto conv_dilation_h = problem.GetDilationH(); @@ -300,7 +298,7 @@ ConvHipImplicitGemmV4R1Fwd::GetSolution(const ConvolutionContext& ctx, // Borrowed from non-padded version of v4 InBlockCopySrcDataPerRead_B = - problem.GetWeightsWidth() > 1 + problem.GetWeightsWidth_() > 1 ? std::min(InBlockCopySrcDataPerRead_B, GetReadWriteVectorSize(problem.GetDilationW())) : InBlockCopySrcDataPerRead_B; InBlockCopySrcDataPerRead_B = problem.GetKernelStrideW() > 1 ? 1 : InBlockCopySrcDataPerRead_B; @@ -401,15 +399,15 @@ ConvHipImplicitGemmV4R1WrW::GetSolution(const ConvolutionContext& ctx, // retrieve dimension from ProblemDescription // remember: ProblemDescription has swapped some dimensions for you! // undo the swap to avoid confusion - const auto n = problem.GetBatchSize(); - const auto k = problem.GetInChannels(); // unswap - const auto c = problem.GetOutChannels(); // unswap - const auto hi = problem.GetOutHeight(); // unswap - const auto wi = problem.GetOutWidth(); // unswap - const auto ho = problem.GetInHeight(); // unswap - const auto wo = problem.GetInWidth(); // unswap - const auto y = problem.GetWeightsHeight(); - const auto x = problem.GetWeightsWidth(); + const int n = problem.GetBatchSize_(); + const int k = problem.GetInChannels_(); // unswap + const int c = problem.GetOutChannels_(); // unswap + const int hi = problem.GetOutHeight_(); // unswap + const int wi = problem.GetOutWidth_(); // unswap + const int ho = problem.GetInHeight_(); // unswap + const int wo = problem.GetInWidth_(); // unswap + const int y = problem.GetWeightsHeight_(); + const int x = problem.GetWeightsWidth_(); const auto conv_stride_h = problem.GetKernelStrideH(); const auto conv_stride_w = problem.GetKernelStrideW(); const auto conv_dilation_h = problem.GetDilationH(); @@ -515,7 +513,7 @@ ConvHipImplicitGemmV4R1WrW::GetSolution(const ConvolutionContext& ctx, // clang-format off // Borrowed from non-padded version of v4 InBlockCopySrcDataPerRead_B = - problem.GetWeightsWidth() > 1 + problem.GetWeightsWidth_() > 1 ? std::min(InBlockCopySrcDataPerRead_B, GetReadWriteVectorSize(problem.GetDilationW())) : InBlockCopySrcDataPerRead_B; InBlockCopySrcDataPerRead_B = problem.GetKernelStrideW() > 1 ? 1 : InBlockCopySrcDataPerRead_B; diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp index 5d92ba3707..15f247e7d3 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp @@ -579,7 +579,7 @@ bool ConvHipImplicitGemmV4R4Fwd::IsApplicable(const ConvolutionContext& ctx, return false; if(ThisSolverIsDeprecatedStatic::IsDisabled(ctx)) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!ctx.use_hip_kernels) return false; diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp index 3b6f8a1d23..edcce82e68 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp @@ -975,7 +975,7 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops::IsApplicable(const ConvolutionContext if(ThisSolverIsDeprecatedStatic::IsDisabled(ctx)) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!ctx.use_hip_kernels) @@ -996,8 +996,7 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops::IsApplicable(const ConvolutionContext if(!problem.Is2d()) return false; - if(ctx.GetStream().GetDeviceName() == "gfx90a" && - problem.conv_problem.IsGfx90aFp16altRequired()) + if(ctx.GetStream().GetDeviceName() == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; if(!IsIndexRangeLargeEnough(problem)) diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp index 4dac9f573b..95fed60757 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp @@ -1041,7 +1041,7 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm::IsApplicable( if(ThisSolverIsDeprecatedStatic::IsDisabled(ctx)) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!ctx.use_hip_kernels) @@ -1062,8 +1062,7 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm::IsApplicable( if(!problem.Is2d()) return false; - if(ctx.GetStream().GetDeviceName() == "gfx90a" && - problem.conv_problem.IsGfx90aFp16altRequired()) + if(ctx.GetStream().GetDeviceName() == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; if(!IsIndexRangeLargeEnough(problem)) @@ -1095,12 +1094,12 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm::IsApplicable( #if WORKAROUND_MI100_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4_PADDED_GEMM_XDLOPS if(ctx.GetStream().GetDeviceName() == "gfx908" && problem.IsFp32()) { - if((problem.GetInChannels() == 3 && problem.GetOutChannels() == 1 && - problem.GetInWidth() == 227 && problem.GetInHeight() == 227 && - problem.GetWeightsWidth() == 3 && problem.GetWeightsHeight() == 3) // - || (problem.GetInChannels() == 64 && problem.GetOutChannels() == 1 && - problem.GetInWidth() == 112 && problem.GetInHeight() == 112 && - problem.GetWeightsWidth() == 3 && problem.GetWeightsHeight() == 3 && + if((problem.GetInChannels_() == 3 && problem.GetOutChannels_() == 1 && + problem.GetInWidth_() == 227 && problem.GetInHeight_() == 227 && + problem.GetWeightsWidth_() == 3 && problem.GetWeightsHeight_() == 3) // + || (problem.GetInChannels_() == 64 && problem.GetOutChannels_() == 1 && + problem.GetInWidth_() == 112 && problem.GetInHeight_() == 112 && + problem.GetWeightsWidth_() == 3 && problem.GetWeightsHeight_() == 3 && problem.GetKernelStrideW() >= 2 && problem.GetKernelStrideH() >= 2 && problem.GetDilationW() >= 3 && problem.GetDilationH() >= 3)) { diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp index 0745dc60c7..71f97ff3b2 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp @@ -1005,7 +1005,7 @@ bool ConvHipImplicitGemmForwardV4R5Xdlops::IsApplicable(const ConvolutionContext if(ThisSolverIsDeprecatedStatic::IsDisabled(ctx)) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!ctx.use_hip_kernels) @@ -1033,8 +1033,7 @@ bool ConvHipImplicitGemmForwardV4R5Xdlops::IsApplicable(const ConvolutionContext if(!problem.Is2d()) return false; - if(ctx.GetStream().GetDeviceName() == "gfx90a" && - problem.conv_problem.IsGfx90aFp16altRequired()) + if(ctx.GetStream().GetDeviceName() == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; if(!IsIndexRangeLargeEnough(problem)) diff --git a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp index cfe5851eb6..5bd2e24cbb 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp @@ -254,7 +254,7 @@ MakeInvokerFactoryHipImplGemmFwdXdlops(const ProblemDescription& problem, auto ck_args = CKArgs{problem}; const auto config_idx = config.index; - switch(problem.conv_problem.GetInDataType()) + switch(problem.GetInDataType()) { case miopenInt8: return MakeInvokerFactoryHelper(std::move(ck_args), config_idx); case miopenHalf: return MakeInvokerFactoryHelper(std::move(ck_args), config_idx); @@ -281,7 +281,7 @@ void PerformanceConfigHipImplicitGemmFwdXdlops::HeuristicInit(const ProblemDescr this->index = 0; this->total_size = 0; this->kernel_id = ""; - switch(problem.conv_problem.GetInDataType()) + switch(problem.GetInDataType()) { case miopenInt8: Init(problem); break; case miopenHalf: Init(problem); break; @@ -316,7 +316,7 @@ bool PerformanceConfigHipImplicitGemmFwdXdlops::IsValid(const ProblemDescription std::ignore = problem; return false; #else - switch(problem.conv_problem.GetInDataType()) + switch(problem.GetInDataType()) { case miopenInt8: return CheckIsSupportCKArgs(problem); case miopenHalf: return CheckIsSupportCKArgs(problem); @@ -371,11 +371,11 @@ bool ConvHipImplicitGemmFwdXdlops::IsApplicable(const ConvolutionContext& ctx, #else if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS{})) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; - if(problem.conv_problem.GetInDataType() != problem.conv_problem.GetWeightsDataType() || - problem.conv_problem.GetWeightsDataType() != problem.conv_problem.GetOutDataType() || - problem.conv_problem.GetInDataType() != problem.conv_problem.GetOutDataType()) + if(problem.GetInDataType() != problem.GetWeightsDataType() || + problem.GetWeightsDataType() != problem.GetOutDataType() || + problem.GetInDataType() != problem.GetOutDataType()) return false; if(!problem.direction.IsForward()) return false; @@ -386,7 +386,7 @@ bool ConvHipImplicitGemmFwdXdlops::IsApplicable(const ConvolutionContext& ctx, if(!IsComposableKernelSupportedHardware(ctx)) return false; const std::string& arch = ctx.GetStream().GetDeviceName(); - if(arch == "gfx90a" && problem.conv_problem.IsGfx90aFp16altRequired()) + if(arch == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; if(!IsIndexRangeLargeEnough(problem)) return false; @@ -394,7 +394,7 @@ bool ConvHipImplicitGemmFwdXdlops::IsApplicable(const ConvolutionContext& ctx, return false; if(problem.GetGroupCount() > 1) return false; - switch(problem.conv_problem.GetInDataType()) + switch(problem.GetInDataType()) { case miopenInt8: return CheckCKApplicability(problem); case miopenHalf: return CheckCKApplicability(problem); diff --git a/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp index 1031bd175c..a43b476bee 100644 --- a/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp @@ -299,7 +299,7 @@ void PerformanceConfigHipImplicitGemmGroupFwdXdlops::HeuristicInit( #if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL std::ignore = problem; #else - switch(problem.conv_problem.GetInDataType()) + switch(problem.GetInDataType()) { case miopenHalf: Init(problem); break; case miopenFloat: Init(problem); break; @@ -342,7 +342,7 @@ bool PerformanceConfigHipImplicitGemmGroupFwdXdlops::IsValid( std::ignore = problem; return false; #else - switch(problem.conv_problem.GetInDataType()) + switch(problem.GetInDataType()) { case miopenHalf: return CheckIsSupportCKArgs(problem); case miopenFloat: return CheckIsSupportCKArgs(problem); @@ -397,11 +397,11 @@ bool ConvHipImplicitGemmGroupFwdXdlops::IsApplicable(const ConvolutionContext& c #else if(miopen::IsDisabled(MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS{})) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; - if(problem.conv_problem.GetInDataType() != problem.conv_problem.GetWeightsDataType() || - problem.conv_problem.GetWeightsDataType() != problem.conv_problem.GetOutDataType() || - problem.conv_problem.GetInDataType() != problem.conv_problem.GetOutDataType()) + if(problem.GetInDataType() != problem.GetWeightsDataType() || + problem.GetWeightsDataType() != problem.GetOutDataType() || + problem.GetInDataType() != problem.GetOutDataType()) return false; if(!problem.direction.IsForward()) return false; @@ -412,7 +412,7 @@ bool ConvHipImplicitGemmGroupFwdXdlops::IsApplicable(const ConvolutionContext& c const std::string& arch = ctx.GetStream().GetDeviceName(); if(!(arch == "gfx908" || arch == "gfx90a")) return false; - switch(problem.conv_problem.GetInDataType()) + switch(problem.GetInDataType()) { case miopenHalf: return CheckCKApplicability(problem); case miopenFloat: return CheckCKApplicability(problem); @@ -441,7 +441,7 @@ ConvSolution ConvHipImplicitGemmGroupFwdXdlops::GetSolution( result.invoker_factory = [=](const std::vector& kernels) { std::ignore = kernels; return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { - switch(problem.conv_problem.GetInDataType()) + switch(problem.GetInDataType()) { case miopenHalf: RunCKSolution(handle, primitive_parameters, problem, config); diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp index 8f03f4d098..331e8a14c2 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp @@ -582,7 +582,7 @@ bool ConvHipImplicitGemmV4R4WrW::IsApplicable(const ConvolutionContext& ctx, return false; if(ThisSolverIsDeprecatedStatic::IsDisabled(ctx)) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!ctx.use_hip_kernels) return false; diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp index fb7cc5f4f4..f6c4847551 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp @@ -981,7 +981,7 @@ ConvSolution ConvHipImplicitGemmWrwV4R4Xdlops::GetSolution( result.construction_params.push_back(construction_parameters); - const auto& conv = problem.conv_problem.GetConv(); + const auto& conv = problem.GetConv(); const auto& lowp_quant = conv.lowp_quant; result.invoker_factory = [=](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { @@ -1045,7 +1045,7 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops::IsApplicable(const ConvolutionContext& ct if(ThisSolverIsDeprecatedStatic::IsDisabled(ctx)) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!ctx.use_hip_kernels) @@ -1066,8 +1066,7 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops::IsApplicable(const ConvolutionContext& ct if(!problem.Is2d()) return false; - if(ctx.GetStream().GetDeviceName() == "gfx90a" && - problem.conv_problem.IsGfx90aFp16altRequired()) + if(ctx.GetStream().GetDeviceName() == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; if(!IsIndexRangeLargeEnough(problem)) diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp index 6f39f6b2d4..9f46af0245 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp @@ -1047,7 +1047,7 @@ ConvSolution ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm::GetSolution( result.construction_params.push_back(construction_parameters); - const auto& conv = problem.conv_problem.GetConv(); + const auto& conv = problem.GetConv(); const auto& lowp_quant = conv.lowp_quant; result.invoker_factory = [=](const std::vector& kernels) { return [=](const Handle& handle, const AnyInvokeParams& primitive_params) { @@ -1114,7 +1114,7 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm::IsApplicable( if(!IsComposableKernelSupportedHardware(ctx)) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!IsXdlopsSupport(ctx)) @@ -1132,8 +1132,7 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm::IsApplicable( if(!problem.Is2d()) return false; - if(ctx.GetStream().GetDeviceName() == "gfx90a" && - problem.conv_problem.IsGfx90aFp16altRequired()) + if(ctx.GetStream().GetDeviceName() == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; if(!IsIndexRangeLargeEnough(problem)) diff --git a/src/solver/conv_mlir_igemm_bwd.cpp b/src/solver/conv_mlir_igemm_bwd.cpp index c9054a62c7..6fa2b2e7f4 100644 --- a/src/solver/conv_mlir_igemm_bwd.cpp +++ b/src/solver/conv_mlir_igemm_bwd.cpp @@ -43,7 +43,7 @@ bool ConvMlirIgemmBwd::IsApplicable(const ConvolutionContext& ctx, #if MIOPEN_USE_MLIR if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_MLIR_IGEMM_BWD{})) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!problem.direction.IsBackwardData()) return false; diff --git a/src/solver/conv_mlir_igemm_bwd_xdlops.cpp b/src/solver/conv_mlir_igemm_bwd_xdlops.cpp index 9a7fa81650..c55d89464a 100644 --- a/src/solver/conv_mlir_igemm_bwd_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_bwd_xdlops.cpp @@ -44,7 +44,7 @@ bool ConvMlirIgemmBwdXdlops::IsApplicable(const ConvolutionContext& ctx, #if MIOPEN_USE_MLIR if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_MLIR_IGEMM_BWD_XDLOPS{})) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!IsXdlopsSupport(ctx)) return false; diff --git a/src/solver/conv_mlir_igemm_fwd.cpp b/src/solver/conv_mlir_igemm_fwd.cpp index 86314150c8..3a1eb3068d 100644 --- a/src/solver/conv_mlir_igemm_fwd.cpp +++ b/src/solver/conv_mlir_igemm_fwd.cpp @@ -163,7 +163,7 @@ bool ConvMlirIgemmFwd::IsApplicable(const ConvolutionContext& ctx, #if MIOPEN_USE_MLIR if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_MLIR_IGEMM_FWD{})) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!problem.direction.IsForward()) return false; diff --git a/src/solver/conv_mlir_igemm_fwd_xdlops.cpp b/src/solver/conv_mlir_igemm_fwd_xdlops.cpp index 838b7d4d7d..692b2aeba2 100644 --- a/src/solver/conv_mlir_igemm_fwd_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_fwd_xdlops.cpp @@ -58,7 +58,7 @@ bool ConvMlirIgemmFwdXdlops::IsApplicable(const ConvolutionContext& ctx, #if MIOPEN_USE_MLIR if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_MLIR_IGEMM_FWD_XDLOPS{})) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!IsXdlopsSupport(ctx)) return false; diff --git a/src/solver/conv_mlir_igemm_wrw.cpp b/src/solver/conv_mlir_igemm_wrw.cpp index 3683fa95fa..bb1e1229b2 100644 --- a/src/solver/conv_mlir_igemm_wrw.cpp +++ b/src/solver/conv_mlir_igemm_wrw.cpp @@ -44,7 +44,7 @@ bool ConvMlirIgemmWrW::IsApplicable(const ConvolutionContext& ctx, #if MIOPEN_USE_MLIR if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_MLIR_IGEMM_WRW{})) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!problem.direction.IsBackwardWrW()) return false; diff --git a/src/solver/conv_mlir_igemm_wrw_xdlops.cpp b/src/solver/conv_mlir_igemm_wrw_xdlops.cpp index 11145dcabb..34c99d39cd 100644 --- a/src/solver/conv_mlir_igemm_wrw_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_wrw_xdlops.cpp @@ -45,7 +45,7 @@ bool ConvMlirIgemmWrWXdlops::IsApplicable(const ConvolutionContext& ctx, #if MIOPEN_USE_MLIR if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_MLIR_IGEMM_WRW_XDLOPS{})) return false; - if(problem.conv_problem.GetConv().attribute.deterministic) + if(problem.GetConv().attribute.deterministic) return false; if(!IsXdlopsSupport(ctx)) return false; diff --git a/src/solver/conv_multipass_wino3x3WrW.cpp b/src/solver/conv_multipass_wino3x3WrW.cpp index aedc5f6133..e41a434253 100644 --- a/src/solver/conv_multipass_wino3x3WrW.cpp +++ b/src/solver/conv_multipass_wino3x3WrW.cpp @@ -68,15 +68,15 @@ MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_WORKSPACE_MAX) GetSolverWinoXformHWSize(problem, 1); #define DEFINE_SHADER_ALIASES(problem) \ - const auto C = (problem).GetBatchSize(); \ - const auto N = (problem).GetOutChannels(); \ - const auto K = (problem).GetInChannels(); \ - const auto out_H = (problem).GetWeightsHeight(); \ - const auto out_W = (problem).GetWeightsWidth(); \ - const auto R = (problem).GetInHeight(); \ - const auto S = (problem).GetInWidth(); \ - const auto H = (problem).GetOutHeight(); \ - const auto W = (problem).GetOutWidth(); \ + const int C = (problem).GetBatchSize_(); \ + const int N = (problem).GetOutChannels_(); \ + const int K = (problem).GetInChannels_(); \ + const int out_H = (problem).GetWeightsHeight_(); \ + const int out_W = (problem).GetWeightsWidth_(); \ + const int R = (problem).GetInHeight_(); \ + const int S = (problem).GetInWidth_(); \ + const int H = (problem).GetOutHeight_(); \ + const int W = (problem).GetOutWidth_(); \ DEFINE_GETXFORMHWSIZE(problem) template @@ -455,7 +455,7 @@ bool ConvWinograd3x3MultipassWrW if(!(StartsWith(name, "gfx8") || StartsWith(name, "gfx9")) || StartsWith(name, "gfx94")) return false; - if(name == "gfx90a" && problem.conv_problem.IsGfx90aFp16altRequired()) + if(name == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; { @@ -496,26 +496,26 @@ bool ConvWinograd3x3MultipassWrW // clang-format off { - const int64_t input_line_size = static_cast(4) * problem.GetInWidth(); - const int64_t input_feature_map_size = input_line_size * problem.GetInHeight(); - const int64_t input_stack_size = input_feature_map_size * problem.GetInChannels(); + const int64_t input_line_size = static_cast(4) * problem.GetInWidth_(); + const int64_t input_feature_map_size = input_line_size * problem.GetInHeight_(); + const int64_t input_stack_size = input_feature_map_size * problem.GetInChannels_(); if (! (input_stack_size < (1U << 24))) return false; } bool ok = ( - (problem.GetWeightsWidth() == WinoDataW && problem.GetWeightsHeight() == WinoDataH) + (problem.GetWeightsWidth_() == WinoDataW && problem.GetWeightsHeight_() == WinoDataH) && (problem.GetKernelStrideW() == 1 || - (problem.GetKernelStrideW() == 2 && problem.GetWeightsHeight() == 3 && problem.GetWeightsWidth() == 3) + (problem.GetKernelStrideW() == 2 && problem.GetWeightsHeight_() == 3 && problem.GetWeightsWidth_() == 3) ) && problem.GetKernelStrideH() == problem.GetKernelStrideW() && problem.GetDilationW() == 1 && problem.GetDilationH() == 1 - && problem.GetBatchSize() < std::pow(2, 24) - && problem.GetInChannels() < std::pow(2, 24) - && problem.GetOutChannels() < std::pow(2, 24) - && problem.GetInHeight() < std::pow(2, 24) - && problem.GetInWidth() < std::pow(2, 24) + && problem.GetBatchSize_() < std::pow(2, 24) + && problem.GetInChannels_() < std::pow(2, 24) + && problem.GetOutChannels_() < std::pow(2, 24) + && problem.GetInHeight_() < std::pow(2, 24) + && problem.GetInWidth_() < std::pow(2, 24) && problem.GetBias() == 0 && problem.GetInLayout() == "NCHW" && problem.GetGroupCount() == 1); @@ -676,7 +676,7 @@ ConvWinograd3x3MultipassWrW::Pre // clang-format off GemmDescriptor wino_gemm_desc{false,false,true,m,n,k, lda,ldb,ldc,batch_count,strideA,strideB, - strideC,alpha,beta,in_data_type, problem.conv_problem.GetConv().attribute.deterministic}; + strideC,alpha,beta,in_data_type, problem.GetConv().attribute.deterministic}; CallGemmStridedBatched(handle, wino_gemm_desc, diff --git a/src/solver/conv_ocl_dir2D11x11.cpp b/src/solver/conv_ocl_dir2D11x11.cpp index 7a6a4080a2..ec56fe9f56 100644 --- a/src/solver/conv_ocl_dir2D11x11.cpp +++ b/src/solver/conv_ocl_dir2D11x11.cpp @@ -58,7 +58,7 @@ bool ConvOclDirectFwd11x11::IsApplicable(const ConvolutionContext& ctx, return problem.direction.IsForward() && problem.GetGroupCount() == 1 && problem.GetDilationH() == 1 && problem.GetDilationW() == 1 && - problem.GetWeightsHeight() == 11 && problem.GetWeightsWidth() == 11 && + problem.GetWeightsHeight_() == 11 && problem.GetWeightsWidth_() == 11 && problem.GetKernelStrideH() == 4 && problem.GetKernelStrideW() == 4; } @@ -72,23 +72,22 @@ ConvSolution ConvOclDirectFwd11x11::GetSolution(const ConvolutionContext& ctx, // auto dev_local_mem_sz = localMemSize; // in bytes // major parameters int LG2_WAVE_SZ = mloLg2(hw_wave_sz); - int wei_cstride = problem.GetWeightsWidth() * problem.GetWeightsHeight(); + int wei_cstride = problem.GetWeightsWidth_() * problem.GetWeightsHeight_(); int wei_bstride = - (is_forward ? problem.GetInChannels() : problem.GetOutChannels()) * wei_cstride; + (is_forward ? problem.GetInChannels_() : problem.GetOutChannels_()) * wei_cstride; // number of batch iterations - result.n_stacks = 1; - result.n_stacks = std::min(problem.GetBatchSize(), result.n_stacks); + result.n_stacks = std::min(problem.GetBatchSize_(), 1U); // defines how to proceed : 1 grouop per batch or with a loop over all batches // loop over al batches make sense in 2 cases: a lot of small inputs/outputs or few batches // param int N_BATCH_LOOPS = 1; // (_n_inputs*_n_outputs <= 8 * 1024) ? 1 : _batch_sz / _n_stacks; - int n_batch_blks = (problem.GetBatchSize() + N_BATCH_LOOPS * result.n_stacks - 1) / + int n_batch_blks = (problem.GetBatchSize_() + N_BATCH_LOOPS * result.n_stacks - 1) / (N_BATCH_LOOPS * result.n_stacks); - int N_FILTER_SPLITS0 = - ((problem.GetWeightsWidth() + problem.GetKernelStrideW() - 1) / problem.GetKernelStrideW()); - int N_FILTER_SPLITS1 = ((problem.GetWeightsHeight() + problem.GetKernelStrideH() - 1) / + int N_FILTER_SPLITS0 = ((problem.GetWeightsWidth_() + problem.GetKernelStrideW() - 1) / + problem.GetKernelStrideW()); + int N_FILTER_SPLITS1 = ((problem.GetWeightsHeight_() + problem.GetKernelStrideH() - 1) / problem.GetKernelStrideH()); static const int data_multiplier0 = @@ -128,9 +127,10 @@ ConvSolution ConvOclDirectFwd11x11::GetSolution(const ConvolutionContext& ctx, // generate full output width // extent1 == MLO_GRP_SZ / MLO_PROCESING_WIDTH int PROCESING_WIDTH = - ((problem.GetOutWidth() + result.out_pix_tile0 - 1) / result.out_pix_tile0); + ((problem.GetOutWidth_() + result.out_pix_tile0 - 1) / result.out_pix_tile0); - int OUT_EXTENT1 = std::min(problem.GetOutHeight(), (GRP_SZ / PROCESING_WIDTH)); + int OUT_EXTENT1 = + std::min(static_cast(problem.GetOutHeight_()), (GRP_SZ / PROCESING_WIDTH)); // define a special size for a specific width as a devisor to avoid dealing with out of range // param @@ -150,16 +150,17 @@ ConvSolution ConvOclDirectFwd11x11::GetSolution(const ConvolutionContext& ctx, // n_in_stacks input map will be written in the local memory. int n_in_stacks = 1; - n_in_stacks = std::min(problem.GetInChannels(), n_in_stacks); - n_out_stacks = std::min(problem.GetOutChannels(), n_out_stacks); + n_in_stacks = std::min(static_cast(problem.GetInChannels_()), n_in_stacks); + n_out_stacks = std::min(static_cast(problem.GetOutChannels_()), n_out_stacks); // param // 6 get us the min // cppcheck-suppress knownConditionTrueFalse - static const int backwards_min_output = (data_multiplier1 > 1 || data_multiplier0 > 1) ? 1 : 4; + static const unsigned backwards_min_output = + (data_multiplier1 > 1 || data_multiplier0 > 1) ? 1 : 4; result.n_out_pix_tiles = - (is_forward) ? std::min(6, (problem.GetOutChannels() + n_out_stacks - 1) / n_out_stacks) - : std::min(problem.GetOutChannels(), backwards_min_output); + (is_forward) ? std::min(6U, (problem.GetOutChannels_() + n_out_stacks - 1) / n_out_stacks) + : std::min(problem.GetOutChannels_(), backwards_min_output); // number of maps in a stack or number of input read blocks written into 1 wk-item (lane) // param @@ -174,27 +175,27 @@ ConvSolution ConvOclDirectFwd11x11::GetSolution(const ConvolutionContext& ctx, int grp_tile2 = 1; // second pass if needed - int n_extents = ((problem.GetOutHeight() + OUT_EXTENT1 - 1) / OUT_EXTENT1); - int n_output_map_blocks = ((problem.GetOutChannels() + total_out_maps - 1) / total_out_maps); - int last_out_extent1 = - problem.GetOutHeight() - (std::max(1, problem.GetOutHeight() / OUT_EXTENT1) * OUT_EXTENT1); + int n_extents = ((problem.GetOutHeight_() + OUT_EXTENT1 - 1) / OUT_EXTENT1); + int n_output_map_blocks = ((problem.GetOutChannels_() + total_out_maps - 1) / total_out_maps); + int last_out_extent1 = problem.GetOutHeight_() - + (std::max(1U, problem.GetOutHeight_() / OUT_EXTENT1) * OUT_EXTENT1); last_out_extent1 = (last_out_extent1 < 0) ? 0 : last_out_extent1; int n_batches_pass2 = 1; bool second_pass = false; if(is_forward && 0 < last_out_extent1 && last_out_extent1 <= OUT_EXTENT1 / 2) { - n_extents = std::max(1, problem.GetOutHeight() / OUT_EXTENT1); + n_extents = std::max(1U, problem.GetOutHeight_() / OUT_EXTENT1); n_batches_pass2 = std::max(1, GRP_SZ / (PROCESING_WIDTH * last_out_extent1)); second_pass = true; } // calc bwd grid - int n_out_pix_tiles1 = - (problem.GetOutHeight() + result.out_pix_tile1 - 1 + 2 * problem.GetPadH()) / - result.out_pix_tile1; - int n_out_pix_tiles0 = - (problem.GetOutWidth() + result.out_pix_tile0 - 1 + 2 * problem.GetPadW()) / - result.out_pix_tile0; + int n_out_pix_tiles1 = (static_cast(problem.GetOutHeight_()) + result.out_pix_tile1 - 1 + + 2 * problem.GetPadH()) / + result.out_pix_tile1; + int n_out_pix_tiles0 = (static_cast(problem.GetOutWidth_()) + result.out_pix_tile0 - 1 + + 2 * problem.GetPadW()) / + result.out_pix_tile0; int n_out_pix_tiles = n_out_pix_tiles1 * n_out_pix_tiles0; // calculate lcl mem size for backward data @@ -225,30 +226,30 @@ ConvSolution ConvOclDirectFwd11x11::GetSolution(const ConvolutionContext& ctx, std::to_string(result.grp_tile0) + std::string(" -DMLO_GRP_SZ1=") + std::to_string(result.grp_tile1) + std::string(" -DMLO_GRP_SZ2=") + std::to_string(grp_tile2) + std::string(" -DMLO_FILTER_SIZE0=") + - std::to_string(problem.GetWeightsWidth()) + std::string(" -DMLO_FILTER_SIZE1=") + - std::to_string(problem.GetWeightsHeight()) + std::string(" -DMLO_FILTER_PAD0=") + + std::to_string(problem.GetWeightsWidth_()) + std::string(" -DMLO_FILTER_SIZE1=") + + std::to_string(problem.GetWeightsHeight_()) + std::string(" -DMLO_FILTER_PAD0=") + std::to_string(problem.GetPadW()) + std::string(" -DMLO_FILTER_PAD1=") + std::to_string(problem.GetPadH()) + std::string(" -DMLO_FILTER_STRIDE0=") + std::to_string(problem.GetKernelStrideW()) + std::string(" -DMLO_FILTER_STRIDE1=") + std::to_string(problem.GetKernelStrideH()) + std::string(" -DSTRIDE_W=") + std::to_string(problem.GetKernelStrideW()) + std::string(" -DSTRIDE_H=") + std::to_string(problem.GetKernelStrideH()) + std::string(" -DMLO_N_OUTPUTS=") + - std::to_string(problem.GetOutChannels()) + std::string(" -DMLO_N_INPUTS=") + - std::to_string(problem.GetInChannels()) + std::string(" -DMLO_BATCH_SZ=") + - std::to_string(problem.GetBatchSize()) + std::string(" -DMLO_N_BATCH_LOOPS=") + + std::to_string(problem.GetOutChannels_()) + std::string(" -DMLO_N_INPUTS=") + + std::to_string(problem.GetInChannels_()) + std::string(" -DMLO_BATCH_SZ=") + + std::to_string(problem.GetBatchSize_()) + std::string(" -DMLO_N_BATCH_LOOPS=") + std::to_string(N_BATCH_LOOPS) + std::string(" -DMLO_OUT_BATCH_STRIDE=") + - std::to_string(problem.GetOutBatchStride()) + std::string(" -DMLO_OUT_CHANNEL_STRIDE=") + - std::to_string(problem.GetOutChannelStride()) + std::string(" -DMLO_OUT_STRIDE=") + - std::to_string(problem.GetOutStride()) + std::string(" -DMLO_IN_BATCH_STRIDE=") + - std::to_string(problem.GetInBatchStride()) + std::string(" -DMLO_IN_CHANNEL_STRIDE=") + - std::to_string(problem.GetInChannelStride()) + std::string(" -DMLO_IN_STRIDE=") + - std::to_string(problem.GetInStride()) + std::string(" -DMLO_WEI_BATCH_STRIDE=") + + std::to_string(problem.GetOutBatchStride_()) + std::string(" -DMLO_OUT_CHANNEL_STRIDE=") + + std::to_string(problem.GetOutChannelStride_()) + std::string(" -DMLO_OUT_STRIDE=") + + std::to_string(problem.GetOutStrideH_()) + std::string(" -DMLO_IN_BATCH_STRIDE=") + + std::to_string(problem.GetInBatchStride_()) + std::string(" -DMLO_IN_CHANNEL_STRIDE=") + + std::to_string(problem.GetInChannelStride_()) + std::string(" -DMLO_IN_STRIDE=") + + std::to_string(problem.GetInStrideH_()) + std::string(" -DMLO_WEI_BATCH_STRIDE=") + std::to_string(wei_bstride) + std::string(" -DMLO_WEI_CHANNEL_STRIDE=") + std::to_string(wei_cstride) + std::string(" -DMLO_IN_WIDTH=") + - std::to_string(problem.GetInWidth()) + std::string(" -DMLO_IN_HEIGHT=") + - std::to_string(problem.GetInHeight()) + std::string(" -DMLO_OUT_WIDTH=") + - std::to_string(problem.GetOutWidth()) + std::string(" -DMLO_OUT_HEIGHT=") + - std::to_string(problem.GetOutHeight()) + std::string(" -DMLO_IN_TILE1=") + + std::to_string(problem.GetInWidth_()) + std::string(" -DMLO_IN_HEIGHT=") + + std::to_string(problem.GetInHeight_()) + std::string(" -DMLO_OUT_WIDTH=") + + std::to_string(problem.GetOutWidth_()) + std::string(" -DMLO_OUT_HEIGHT=") + + std::to_string(problem.GetOutHeight_()) + std::string(" -DMLO_IN_TILE1=") + std::to_string(result.in_tile1) + std::string(" -DMLO_IN_TILE0=") + std::to_string(result.in_tile0) + std::string(" -DMLO_N_LCL_BATCHS=") + std::to_string(result.n_stacks) // # of diff stacks (part of batch). @@ -326,7 +327,7 @@ ConvSolution ConvOclDirectFwd11x11::GetSolution(const ConvolutionContext& ctx, size_t gbl_wk0 = GRP_SZ; size_t gbl_wk1 = n_output_map_blocks; - n_batch_blks = (problem.GetBatchSize() + n_batches_pass2 - 1) / n_batches_pass2; + n_batch_blks = (problem.GetBatchSize_() + n_batches_pass2 - 1) / n_batches_pass2; size_t gbl_wk2 = n_batch_blks; construction_parameters.g_wk.push_back(gbl_wk0); diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp index 4cce72c9e5..562e98c366 100644 --- a/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp +++ b/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp @@ -67,12 +67,12 @@ bool ConvOclBwdWrW1x1::IsApplicable(const ConvolutionContext& ctx, return false; } - bool result = (problem.GetWeightsWidth() == 1 && problem.GetWeightsHeight() == 1 && + bool result = (problem.GetWeightsWidth_() == 1 && problem.GetWeightsHeight_() == 1 && problem.GetDilationW() == 1 && problem.GetDilationH() == 1 && problem.GetGroupCount() == 1); // Does not support strides > 1 if not multiple of 16 - if((problem.GetInChannels() & 0xF) > 0 || (problem.GetOutChannels() & 0xF) > 0) + if((problem.GetInChannels_() & 0xF) > 0 || (problem.GetOutChannels_() & 0xF) > 0) result = false; return result; @@ -82,7 +82,8 @@ static inline int GetNPasses(const ProblemDescription& problem) { const int n_passes = #if TWO_PASSES - ((problem.GetBatchSize() >= 16 || 2 * problem.GetOutChannels() > problem.GetInChannels()) && + ((problem.GetBatchSize_() >= 16 || + 2 * problem.GetOutChannels_() > problem.GetInChannels_()) && problem.GetPadH() == 0 && problem.GetPadW() == 0 && (problem.GetKernelStrideW() > 1 || problem.GetKernelStrideH() > 1)) ? 2 @@ -96,13 +97,13 @@ size_t ConvOclBwdWrW1x1::GetWorkspaceSize(const ConvolutionContext&, const ProblemDescription& problem) const { const int n_passes = GetNPasses(problem); - if(((problem.GetInChannels() & 0xF) == 0 && (problem.GetOutChannels() & 0xF) == 0) && + if(((problem.GetInChannels_() & 0xF) == 0 && (problem.GetOutChannels_() & 0xF) == 0) && (n_passes > 1 && problem.GetPadH() == 0 && problem.GetPadW() == 0 && (problem.GetKernelStrideW() > 1 || problem.GetKernelStrideH() > 1))) { - const auto in_channel_stride = problem.GetInStride() * problem.GetInHeight(); - const auto in_batch_stride = in_channel_stride * problem.GetOutChannels(); - return GetTypeSize(problem.GetOutDataType()) * in_batch_stride * problem.GetBatchSize(); + const auto in_channel_stride = problem.GetInStrideH_() * problem.GetInHeight_(); + const auto in_batch_stride = in_channel_stride * problem.GetOutChannels_(); + return GetTypeSize(problem.GetOutDataType()) * in_batch_stride * problem.GetBatchSize_(); } else return 0; @@ -116,15 +117,14 @@ ConvSolution ConvOclBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, // FIX ME! FIX ME! FIX ME! Does not support C, K != 16X yet // NON-Stride/PAD mode NON-16X will be supported by MIOpenConvBwdWrW1x1.CL - if((problem.GetInChannels() & 0xF) == 0 && (problem.GetOutChannels() & 0xF) == 0) + if((problem.GetInChannels_() & 0xF) == 0 && (problem.GetOutChannels_() & 0xF) == 0) { - // problem.GetInChannels()==> C - // problem.GetOutChannels()==>K + // problem.GetInChannels_()==> C + // problem.GetOutChannels_()==>K // Jian: following kernel uses C as input, K as output, different from original definition // FIX ME! FIX ME! FIX ME! // JIANYANG: not know the meaning of following ==> - result.n_stacks = 1; - result.n_stacks = std::min(problem.GetBatchSize(), result.n_stacks); + result.n_stacks = std::min(problem.GetBatchSize_(), 1U); result.out_pix_tile0 = 1; result.out_pix_tile1 = 1; result.in_tile1 = 1; @@ -134,13 +134,13 @@ ConvSolution ConvOclBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, // 8/16/64 int n_lcl_in_maps = 8; - /*if(4 *((problem.GetOutChannels()+63)/64) * ((problem.GetInChannels()+63)/64) >=512) + /*if(4 *((problem.GetOutChannels_()+63)/64) * ((problem.GetInChannels_()+63)/64) >=512) { n_lcl_in_maps =64; } else */ - if(4 * ((problem.GetOutChannels() + 15) / 16) * ((problem.GetInChannels() + 15) / 16) >= + if(4 * ((problem.GetOutChannels_() + 15) / 16) * ((problem.GetInChannels_() + 15) / 16) >= 512) { n_lcl_in_maps = 16; @@ -151,8 +151,8 @@ ConvSolution ConvOclBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, int n_grp_size0 = 64; - int n_out_blocks = ((problem.GetInChannels() + n_lcl_out_maps - 1) / n_lcl_out_maps); - int n_in_blocks = ((problem.GetOutChannels() + n_lcl_in_maps - 1) / n_lcl_in_maps); + int n_out_blocks = ((problem.GetInChannels_() + n_lcl_out_maps - 1) / n_lcl_out_maps); + int n_in_blocks = ((problem.GetOutChannels_() + n_lcl_in_maps - 1) / n_lcl_in_maps); int total_waves = n_in_blocks * n_out_blocks; result.n_out_pix_tiles = n_lcl_out_maps; @@ -219,35 +219,35 @@ ConvSolution ConvOclBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, int read_unit = 4; // subsampled input - int in_width = (n_passes > 1) ? problem.GetInWidth() : problem.GetOutWidth(); - int in_height = (n_passes > 1) ? problem.GetInHeight() : problem.GetOutHeight(); - int in_stride = (n_passes > 1) ? problem.GetInStride() : problem.GetOutStride(); + int in_width = (n_passes > 1) ? problem.GetInWidth_() : problem.GetOutWidth_(); + int in_height = (n_passes > 1) ? problem.GetInHeight_() : problem.GetOutHeight_(); + int in_stride = (n_passes > 1) ? problem.GetInStrideH_() : problem.GetOutStrideH_(); int in_channel_stride = - (n_passes > 1) ? in_stride * in_height : problem.GetOutChannelStride(); - int in_batch_stride = (n_passes > 1) ? in_channel_stride * problem.GetOutChannels() - : problem.GetOutBatchStride(); - int out_batch_stride = problem.GetInBatchStride(); - int out_channel_stride = problem.GetInChannelStride(); - int out_stride = problem.GetInStride(); - int wei_batch_stride = problem.GetInChannels() * problem.GetOutChannels() * - problem.GetWeightsWidth() * problem.GetWeightsHeight(); + (n_passes > 1) ? in_stride * in_height : problem.GetOutChannelStride_(); + int in_batch_stride = (n_passes > 1) ? in_channel_stride * problem.GetOutChannels_() + : problem.GetOutBatchStride_(); + int out_batch_stride = problem.GetInBatchStride_(); + int out_channel_stride = problem.GetInChannelStride_(); + int out_stride = problem.GetInStrideH_(); + int wei_batch_stride = problem.GetInChannels_() * problem.GetOutChannels_() * + problem.GetWeightsWidth_() * problem.GetWeightsHeight_(); int wei_channel_stride = - problem.GetOutChannels() * problem.GetWeightsWidth() * problem.GetWeightsHeight(); - int max_loads_per_readunit = (out_channel_stride / read_unit) * problem.GetBatchSize(); + problem.GetOutChannels_() * problem.GetWeightsWidth_() * problem.GetWeightsHeight_(); + int max_loads_per_readunit = (out_channel_stride / read_unit) * problem.GetBatchSize_(); // limited shape size shows better performance with ead_uint == 3 /* if( (out_channel_stride % 3) == 1) { read_unit = 3; - max_loads_per_readunit = (out_channel_stride / read_unit) * problem.GetBatchSize(); + max_loads_per_readunit = (out_channel_stride / read_unit) * problem.GetBatchSize_(); } */ int out_pad_min_x = 0; int out_pad_min_y = 0; - int out_pad_width = problem.GetInWidth(); - int out_pad_height = problem.GetInHeight(); + int out_pad_width = problem.GetInWidth_(); + int out_pad_height = problem.GetInHeight_(); int in_pad_min_x = 0; int in_pad_min_y = 0; @@ -261,9 +261,9 @@ ConvSolution ConvOclBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, out_pad_min_x = (problem.GetPadW() + problem.GetKernelStrideW() - 1) / problem.GetKernelStrideW(); - out_pad_width = - (problem.GetOutWidth() - in_pad_min_x + problem.GetKernelStrideW() - 1) / - problem.GetKernelStrideW(); + out_pad_width = (static_cast(problem.GetOutWidth_()) - in_pad_min_x + + problem.GetKernelStrideW() - 1) / + problem.GetKernelStrideW(); } if(problem.GetPadH() > 0) { @@ -274,9 +274,9 @@ ConvSolution ConvOclBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, out_pad_min_y = (problem.GetPadH() + problem.GetKernelStrideH() - 1) / problem.GetKernelStrideH(); - out_pad_height = - (problem.GetOutHeight() - in_pad_min_y + problem.GetKernelStrideH() - 1) / - problem.GetKernelStrideH(); + out_pad_height = (static_cast(problem.GetOutHeight_()) - in_pad_min_y + + problem.GetKernelStrideH() - 1) / + problem.GetKernelStrideH(); } if(problem.GetPadW() > 0 || problem.GetPadH() > 0 || @@ -289,8 +289,8 @@ ConvSolution ConvOclBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, // read_unit = (out_pad_width % 7 == 0) ? 7 : (out_pad_width % 5 == 0) ? 5 : // (out_pad_width % 4 == 0) ? 4 : (out_pad_width % 3 == 0) ? 3 : (out_pad_width % 2 // == 0) ? 2 : 1; - max_loads_per_readunit = - (out_pad_width / read_unit) * out_pad_height * problem.GetBatchSize(); + max_loads_per_readunit = (out_pad_width / read_unit) * out_pad_height * + static_cast(problem.GetBatchSize_()); } int kernel_stride_w = problem.GetKernelStrideW(); @@ -317,9 +317,9 @@ ConvSolution ConvOclBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, : 1; int n_grp0_size0 = 256; // real input strides - int in0_stride = problem.GetOutStride(); - int in0_channel_stride = problem.GetOutChannelStride(); - int in0_batch_stride = problem.GetOutBatchStride(); + int in0_stride = problem.GetOutStrideH_(); + int in0_channel_stride = problem.GetOutChannelStride_(); + int in0_batch_stride = problem.GetOutBatchStride_(); int kernel0_stride0 = problem.GetKernelStrideW(); int kernel0_stride1 = problem.GetKernelStrideH(); @@ -328,21 +328,21 @@ ConvSolution ConvOclBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, std::string(" -DMLO_GRP_SZ1=1 ") + std::string(" -DMLO_GRP_SZ2=1 ") + std::string(" -DMLO_GRP0_SZ0=") + std::to_string(n_grp0_size0) + std::string(" -DMLO_GRP0_SZ1=1 ") + std::string(" -DMLO_GRP0_SZ2=1 ") + - std::string(" -DMLO_FILTER_SIZE0=") + std::to_string(problem.GetWeightsWidth()) + - std::string(" -DMLO_FILTER_SIZE1=") + std::to_string(problem.GetWeightsHeight()) + + std::string(" -DMLO_FILTER_SIZE0=") + std::to_string(problem.GetWeightsWidth_()) + + std::string(" -DMLO_FILTER_SIZE1=") + std::to_string(problem.GetWeightsHeight_()) + std::string(" -DMLO_FILTER_PAD0=") + std::to_string(problem.GetPadW()) + std::string(" -DMLO_FILTER_PAD1=") + std::to_string(problem.GetPadH()) + std::string(" -DMLO_FILTER_STRIDE0=") + std::to_string(kernel_stride_w) + std::string(" -DMLO_FILTER_STRIDE1=") + std::to_string(kernel_stride_h) + std::string(" -DMLO_FILTER0_STRIDE0=") + std::to_string(kernel0_stride0) + std::string(" -DMLO_FILTER0_STRIDE1=") + std::to_string(kernel0_stride1) + - std::string(" -DMLO_N_OUTPUTS=") + std::to_string(problem.GetInChannels()) + - std::string(" -DMLO_N_INPUTS=") + std::to_string(problem.GetOutChannels()) + - std::string(" -DMLO_BATCH_SZ=") + std::to_string(problem.GetBatchSize()) + + std::string(" -DMLO_N_OUTPUTS=") + std::to_string(problem.GetInChannels_()) + + std::string(" -DMLO_N_INPUTS=") + std::to_string(problem.GetOutChannels_()) + + std::string(" -DMLO_BATCH_SZ=") + std::to_string(problem.GetBatchSize_()) + std::string(" -DMLO_IN_WIDTH=") + std::to_string(in_width) + std::string(" -DMLO_IN_HEIGHT=") + std::to_string(in_height) + - std::string(" -DMLO_OUT_WIDTH=") + std::to_string(problem.GetInWidth()) + - std::string(" -DMLO_OUT_HEIGHT=") + std::to_string(problem.GetInHeight()) + + std::string(" -DMLO_OUT_WIDTH=") + std::to_string(problem.GetInWidth_()) + + std::string(" -DMLO_OUT_HEIGHT=") + std::to_string(problem.GetInHeight_()) + std::string(" -DMLO_N_LOAD_DWORDS_PER_MAP_ONCE=") + std::to_string(n_load_dwords_per_map_once) + std::string(" -DMLO_N_LCL_IN_MAPS=") + std::to_string(n_lcl_in_maps) + std::string(" -DMLO_N_LCL_OUT_MAPS=") + @@ -386,7 +386,7 @@ ConvSolution ConvOclBwdWrW1x1::GetSolution(const ConvolutionContext& ctx, kernel.l_wk.push_back(1); // output is number of subsampled input maps size_t gbl_wk0 = (in_batch_stride / write_unit); - size_t gbl_wk1 = problem.GetBatchSize(); + size_t gbl_wk1 = problem.GetBatchSize_(); size_t gbl_wk2 = 1; kernel.g_wk.push_back(gbl_wk0); diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp index 1ee61ad356..d24eb17320 100644 --- a/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp +++ b/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp @@ -135,8 +135,8 @@ inline static bool Inc_2_to_11_optimized(int& v) static bool IsTunable(const ProblemDescription& problem) { return !(problem.GetGroupCount() == 1 && - ((problem.GetWeightsWidth() == 3 && problem.GetWeightsHeight() == 3) || - (problem.GetWeightsWidth() == 1 && problem.GetWeightsHeight() == 1))); + ((problem.GetWeightsWidth_() == 3 && problem.GetWeightsHeight_() == 3) || + (problem.GetWeightsWidth_() == 1 && problem.GetWeightsHeight_() == 1))); } bool ConvOclBwdWrW2NonTunable::IsApplicable(const ConvolutionContext& ctx, @@ -225,7 +225,7 @@ static const int N_STACKS = 1; // number of batch iterations template static size_t GetNBatchBlks(const ProblemDescription& problem) { - return std::ceil(static_cast(problem.GetBatchSize()) / (N_BATCH_LOOPS * N_STACKS)); + return std::ceil(static_cast(problem.GetBatchSize_()) / (N_BATCH_LOOPS * N_STACKS)); } template @@ -244,26 +244,26 @@ bool PerformanceConfigConvOclBwdWrw2::IsValid( // Ensure that the total amount of system memory used by intermediate object // that holds the weights of x number of batches doesn't exceed system memory size_t wei_cstride = - static_cast(problem.GetWeightsHeight()) * problem.GetWeightsWidth(); - size_t wei_bstride = (problem.GetOutChannels() / problem.GetGroupCount()) * wei_cstride; + static_cast(problem.GetWeightsHeight_()) * problem.GetWeightsWidth_(); + size_t wei_bstride = (problem.GetOutChannels_() / problem.GetGroupCount()) * wei_cstride; // number of batch iterations const size_t n_batch_blks = GetNBatchBlks(problem); // guard not to grab too much system memory - if(n_batch_blks < 1 || (wei_bstride * problem.GetInChannels() * n_batch_blks) > + if(n_batch_blks < 1 || (wei_bstride * problem.GetInChannels_() * n_batch_blks) > ctx.GetStream().GetMaxMemoryAllocSize()) { return false; } // Check 2: read size - if(problem.GetInWidth() < read_size) + if(problem.GetInWidth_() < read_size) { return false; } size_t aligned_out_scan_lane = - std::ceil(static_cast(problem.GetInWidth()) / read_size); // image aligned scan + std::ceil(static_cast(problem.GetInWidth_()) / read_size); // image aligned scan // Check 3: n_out_channels_tiles if(problem.GetGroupCount() > 1 && n_out_channels_tiles > 1) @@ -271,7 +271,7 @@ bool PerformanceConfigConvOclBwdWrw2::IsValid( return false; } - size_t n_output_channels_per_group = problem.GetInChannels() / problem.GetGroupCount(); + size_t n_output_channels_per_group = problem.GetInChannels_() / problem.GetGroupCount(); // Check 4: n_out_channels_per_tile if(problem.GetGroupCount() > 1 && n_out_channels_per_tile > n_output_channels_per_group) @@ -293,14 +293,14 @@ bool PerformanceConfigConvOclBwdWrw2::IsValid( return false; } - if(n_out_rows_in_lcl < problem.GetWeightsHeight()) + if(n_out_rows_in_lcl < problem.GetWeightsHeight_()) { return false; } // Check 5: n_out_rows_in_lcl should exceed LDS limit size_t in_lcl_height = - (n_out_rows_in_lcl - 1) * problem.GetKernelStrideH() + problem.GetWeightsHeight(); + (n_out_rows_in_lcl - 1) * problem.GetKernelStrideH() + problem.GetWeightsHeight_(); size_t in_lcl_sz = 0; { // Chao: Reserve space in LDS for left padding, it also reserve @@ -310,13 +310,13 @@ bool PerformanceConfigConvOclBwdWrw2::IsValid( // Also, for the last row, right padding is needed. // Revisit this if encounter failure size_t in_lcl_width = 0; - size_t in_width = problem.GetOutWidth(); // out is in, in is out - size_t out_width = problem.GetInWidth(); + size_t in_width = problem.GetOutWidth_(); // out is in, in is out + size_t out_width = problem.GetInWidth_(); size_t in_lcl_width_effective = std::max( in_width + 2ULL * problem.GetPadW(), std::max(problem.GetPadW() + ((in_width + read_size - 1) / read_size) * read_size, - problem.GetWeightsWidth() + (out_width - 1) * problem.GetKernelStrideW())); + problem.GetWeightsWidth_() + (out_width - 1) * problem.GetKernelStrideW())); size_t in_lcl_width_right_buffer = std::max( static_cast(in_lcl_width_effective - (in_width + 2ULL * problem.GetPadW())), 0); @@ -335,10 +335,10 @@ bool PerformanceConfigConvOclBwdWrw2::IsValid( } // check LDS consumption - size_t wei_per_wkitem = (problem.GetWeightsWidth() <= 7 || - (((problem.GetWeightsWidth() / 2) * 2) != problem.GetWeightsWidth())) - ? problem.GetWeightsWidth() - : problem.GetWeightsWidth() / 2; + size_t wei_per_wkitem = (problem.GetWeightsWidth_() <= 7 || + (((problem.GetWeightsWidth_() / 2) * 2) != problem.GetWeightsWidth_())) + ? problem.GetWeightsWidth_() + : problem.GetWeightsWidth_() / 2; { size_t n_lcl_batchs = N_STACKS; @@ -352,7 +352,7 @@ bool PerformanceConfigConvOclBwdWrw2::IsValid( size_t wei_lcl_sz = 0; size_t max_wei_blk = 0; size_t out_wei_scan_loop = 0; - size_t out_width = problem.GetInWidth(); // out is in, in is out + size_t out_width = problem.GetInWidth_(); // out is in, in is out { const auto hw_wave_size = 64; // TBD Obtain this from handle. @@ -361,8 +361,8 @@ bool PerformanceConfigConvOclBwdWrw2::IsValid( if(wei_per_wkitem == 0) return false; size_t wei_blk_sz0 = - std::ceil(static_cast(problem.GetWeightsWidth()) / wei_per_wkitem); - size_t wei_blk_sz = problem.GetWeightsHeight() * wei_blk_sz0; + std::ceil(static_cast(problem.GetWeightsWidth_()) / wei_per_wkitem); + size_t wei_blk_sz = problem.GetWeightsHeight_() * wei_blk_sz0; if(wei_blk_sz == 0) return false; size_t n_wei_blk = workgroup_size / wei_blk_sz; @@ -402,7 +402,7 @@ bool PerformanceConfigConvOclBwdWrw2::IsValid( { size_t data_len = GetTypeSize(problem.GetOutDataType()); result.workspace_sz = static_cast(wei_bstride) * - static_cast(problem.GetInChannels()) * n_batch_blks * + static_cast(problem.GetInChannels_()) * n_batch_blks * static_cast(data_len); #if WORKAROUND_ISSUE_1185 @@ -423,7 +423,7 @@ void PerformanceConfigConvOclBwdWrw2::HeuristicInit( { n_waves = 1; read_size = 6; - const auto n_output_channels_per_group = problem.GetInChannels() / problem.GetGroupCount(); + const auto n_output_channels_per_group = problem.GetInChannels_() / problem.GetGroupCount(); if(n_output_channels_per_group % 4 == 0) n_out_channels_per_tile = 4; else if(n_output_channels_per_group % 3 == 0) @@ -433,7 +433,7 @@ void PerformanceConfigConvOclBwdWrw2::HeuristicInit( else n_out_channels_per_tile = 1; n_out_channels_tiles = 1; - n_out_rows_in_lcl = problem.GetWeightsHeight(); + n_out_rows_in_lcl = problem.GetWeightsHeight_(); } template @@ -477,22 +477,23 @@ bool ConvOclBwdWrW2::IsApplicableBase(const ConvolutionContext& c // previous read, (MLO_N_ALIGNED_OUT_SCAN_BLK * MLO_FILTER_STRIDE1) of it is fresh read // from device memory. So (MLO_FILTER_SIZE1 - MLO_FILTER_STRIDE1) need no less than 0. // TODO: chao: revisit this if failure is encountered. - problem.GetWeightsHeight() - problem.GetKernelStrideH() >= 0 && + problem.GetWeightsHeight_() >= problem.GetKernelStrideH() && #endif // The first scan of stripe of the input into LDS will read a strip of height // (kernel_size_h - kernel_stride_h), this stripe should include the whole lower bound // padding, as well as some or none of the input. - problem.GetWeightsHeight() - problem.GetKernelStrideH() >= problem.GetPadH() && - problem.GetBatchSize() >= N_BATCH_LOOPS && + static_cast(problem.GetWeightsHeight_()) - problem.GetKernelStrideH() >= + problem.GetPadH() && + problem.GetBatchSize_() >= N_BATCH_LOOPS && /// \todo Workaround for issue 1693 - !(problem.GetWeightsWidth() >= 8 && problem.GetWeightsWidth() % 2 == 0 && + !(problem.GetWeightsWidth_() >= 8 && problem.GetWeightsWidth_() % 2 == 0 && !( // Allow these configs to avoid perf drops: (problem.GetKernelStrideH() == 2 && problem.GetKernelStrideW() == 2) && - (problem.GetWeightsHeight() == 5 && - (problem.GetWeightsWidth() == 10 || problem.GetWeightsWidth() == 20)) && - ((problem.GetOutHeight() == 79 && problem.GetOutWidth() == 341) || - (problem.GetOutHeight() == 161 && problem.GetOutWidth() == 700)))) && + (problem.GetWeightsHeight_() == 5 && + (problem.GetWeightsWidth_() == 10 || problem.GetWeightsWidth_() == 20)) && + ((problem.GetOutHeight_() == 79 && problem.GetOutWidth_() == 341) || + (problem.GetOutHeight_() == 161 && problem.GetOutWidth_() == 700)))) && /// Avoid LDS & Workspace over-allocation. /// \note Required LDS depends on PerformanceConfig. /// We use the default PerformanceConfig here. This guarantees that at least @@ -525,12 +526,12 @@ size_t ConvOclBwdWrW2::GetWorkspaceSize(const ConvolutionContext& const size_t n_batch_blks = GetNBatchBlks(problem); if(n_batch_blks > 1) { - const auto n_input_channels_per_group = problem.GetOutChannels() / problem.GetGroupCount(); - const auto wei_cstride = problem.GetWeightsWidth() * problem.GetWeightsHeight(); + const auto n_input_channels_per_group = problem.GetOutChannels_() / problem.GetGroupCount(); + const auto wei_cstride = problem.GetWeightsWidth_() * problem.GetWeightsHeight_(); const auto wei_bstride = n_input_channels_per_group * wei_cstride; int data_len = GetTypeSize(problem.GetOutDataType()); return static_cast(wei_bstride) * - static_cast(problem.GetInChannels()) * + static_cast(problem.GetInChannels_()) * static_cast(n_batch_blks) * static_cast(data_len); } else @@ -547,28 +548,28 @@ ConvSolution ConvOclBwdWrW2::GetSolution( const auto hw_wave_size = 64; const auto workgroup_size = hw_wave_size * config.n_waves; - const auto n_input_channels_per_group = problem.GetOutChannels() / problem.GetGroupCount(); - const auto n_output_channels_per_group = problem.GetInChannels() / problem.GetGroupCount(); - const auto wei_cstride = problem.GetWeightsWidth() * problem.GetWeightsHeight(); - const auto wei_bstride = n_input_channels_per_group * wei_cstride; + const int n_input_channels_per_group = problem.GetOutChannels_() / problem.GetGroupCount(); + const int n_output_channels_per_group = problem.GetInChannels_() / problem.GetGroupCount(); + const int wei_cstride = problem.GetWeightsWidth_() * problem.GetWeightsHeight_(); + const auto wei_bstride = n_input_channels_per_group * wei_cstride; result.n_in_data_tiles = 1; const size_t n_batch_blks = GetNBatchBlks(problem); size_t total_out_maps = config.n_out_channels_per_tile * config.n_out_channels_tiles; - size_t wei_per_wkitem = (problem.GetWeightsWidth() <= 7 || - (((problem.GetWeightsWidth() / 2) * 2) != problem.GetWeightsWidth())) - ? problem.GetWeightsWidth() - : problem.GetWeightsWidth() / 2; + size_t wei_per_wkitem = (problem.GetWeightsWidth_() <= 7 || + (((problem.GetWeightsWidth_() / 2) * 2) != problem.GetWeightsWidth_())) + ? problem.GetWeightsWidth_() + : problem.GetWeightsWidth_() / 2; // each wave is a filter row std::string READ_TYPE = (config.read_size == 1) ? "_FLOAT" : "_FLOAT" + std::to_string((config.read_size)); - size_t aligned_out_scan_lane = std::ceil(static_cast(problem.GetInWidth()) / + size_t aligned_out_scan_lane = std::ceil(static_cast(problem.GetInWidth_()) / config.read_size); // image aligned scan size_t n_out_blk = - std::ceil(static_cast(problem.GetInHeight()) / config.n_out_rows_in_lcl); + std::ceil(static_cast(problem.GetInHeight_()) / config.n_out_rows_in_lcl); size_t in_lcl_height = - (config.n_out_rows_in_lcl - 1) * problem.GetKernelStrideH() + problem.GetWeightsHeight(); + (config.n_out_rows_in_lcl - 1) * problem.GetKernelStrideH() + problem.GetWeightsHeight_(); size_t in_lcl_width = 0; size_t in_lcl_sz = 0; { @@ -578,8 +579,8 @@ ConvSolution ConvOclBwdWrW2::GetSolution( // is overlapped with the left padding of the next row. // Also, for the last row, right padding is needed. // Revisit this if encounter failure - size_t in_width = problem.GetOutWidth(); // out is in, in is out - size_t out_width = problem.GetInWidth(); + size_t in_width = problem.GetOutWidth_(); // out is in, in is out + size_t out_width = problem.GetInWidth_(); size_t in_lcl_width_effective = std::max(in_width + 2 * static_cast(problem.GetPadW()), @@ -587,7 +588,7 @@ ConvSolution ConvOclBwdWrW2::GetSolution( static_cast( std::ceil(static_cast(in_width) / config.read_size) * config.read_size), - static_cast(problem.GetWeightsWidth()) + + static_cast(problem.GetWeightsWidth_()) + (out_width - 1) * problem.GetKernelStrideW())); size_t in_lcl_width_right_buffer = std::max( @@ -607,7 +608,7 @@ ConvSolution ConvOclBwdWrW2::GetSolution( } size_t out_n_pixels_off = - problem.GetInWidth() - (problem.GetInWidth() / config.read_size) * config.read_size; + problem.GetInWidth_() - (problem.GetInWidth_() / config.read_size) * config.read_size; result.grp_tile0 = workgroup_size; result.grp_tile1 = 1; @@ -631,29 +632,29 @@ ConvSolution ConvOclBwdWrW2::GetSolution( std::to_string((result.grp_tile0)) + std::string(" -DMLO_GRP_SZ1=") + std::to_string((result.grp_tile1)) + std::string(" -DMLO_GRP_SZ2=") + std::to_string((grp_tile2)) + std::string(" -DMLO_FILTER_SIZE0=") + - std::to_string(problem.GetWeightsWidth()) + std::string(" -DMLO_FILTER_SIZE1=") + - std::to_string(problem.GetWeightsHeight()) + std::string(" -DMLO_FILTER_PAD0=") + + std::to_string(problem.GetWeightsWidth_()) + std::string(" -DMLO_FILTER_SIZE1=") + + std::to_string(problem.GetWeightsHeight_()) + std::string(" -DMLO_FILTER_PAD0=") + std::to_string(problem.GetPadW()) + std::string(" -DMLO_FILTER_PAD1=") + std::to_string(problem.GetPadH()) + std::string(" -DMLO_FILTER_STRIDE0=") + std::to_string(problem.GetKernelStrideW()) + std::string(" -DMLO_FILTER_STRIDE1=") + std::to_string(problem.GetKernelStrideH()) + std::string(" -DMLO_N_OUTPUTS=") + - std::to_string(problem.GetInChannels()) + std::string(" -DMLO_N_INPUTS=") + - std::to_string(problem.GetOutChannels()) + std::string(" -DMLO_BATCH_SZ=") + - std::to_string(problem.GetBatchSize()) + std::string(" -DMLO_N_BATCH_LOOPS=") + + std::to_string(problem.GetInChannels_()) + std::string(" -DMLO_N_INPUTS=") + + std::to_string(problem.GetOutChannels_()) + std::string(" -DMLO_BATCH_SZ=") + + std::to_string(problem.GetBatchSize_()) + std::string(" -DMLO_N_BATCH_LOOPS=") + std::to_string(N_BATCH_LOOPS) + std::string(" -DMLO_N_BATCH_BLKS=") + std::to_string(n_batch_blks) + std::string(" -DMLO_OUT_BATCH_STRIDE=") + - std::to_string((problem.GetInBatchStride())) + std::string(" -DMLO_OUT_CHANNEL_STRIDE=") + - std::to_string((problem.GetInChannelStride())) + std::string(" -DMLO_OUT_STRIDE=") + - std::to_string((problem.GetInStride())) + std::string(" -DMLO_IN_BATCH_STRIDE=") + - std::to_string((problem.GetOutBatchStride())) + std::string(" -DMLO_IN_CHANNEL_STRIDE=") + - std::to_string((problem.GetOutChannelStride())) + std::string(" -DMLO_IN_STRIDE=") + - std::to_string((problem.GetOutStride())) + std::string(" -DMLO_WEI_BATCH_STRIDE=") + + std::to_string((problem.GetInBatchStride_())) + std::string(" -DMLO_OUT_CHANNEL_STRIDE=") + + std::to_string((problem.GetInChannelStride_())) + std::string(" -DMLO_OUT_STRIDE=") + + std::to_string((problem.GetInStrideH_())) + std::string(" -DMLO_IN_BATCH_STRIDE=") + + std::to_string((problem.GetOutBatchStride_())) + std::string(" -DMLO_IN_CHANNEL_STRIDE=") + + std::to_string((problem.GetOutChannelStride_())) + std::string(" -DMLO_IN_STRIDE=") + + std::to_string((problem.GetOutStrideH_())) + std::string(" -DMLO_WEI_BATCH_STRIDE=") + std::to_string((wei_bstride)) + std::string(" -DMLO_WEI_CHANNEL_STRIDE=") + std::to_string((wei_cstride)) + std::string(" -DMLO_IN_WIDTH=") + - std::to_string((problem.GetOutWidth())) + std::string(" -DMLO_IN_HEIGHT=") + - std::to_string(problem.GetOutHeight()) + std::string(" -DMLO_OUT_WIDTH=") + - std::to_string(problem.GetInWidth()) + std::string(" -DMLO_OUT_HEIGHT=") + - std::to_string(problem.GetInHeight()) + std::string(" -DMLO_N_LCL_OUT_MAPS=") + + std::to_string((problem.GetOutWidth_())) + std::string(" -DMLO_IN_HEIGHT=") + + std::to_string(problem.GetOutHeight_()) + std::string(" -DMLO_OUT_WIDTH=") + + std::to_string(problem.GetInWidth_()) + std::string(" -DMLO_OUT_HEIGHT=") + + std::to_string(problem.GetInHeight_()) + std::string(" -DMLO_N_LCL_OUT_MAPS=") + std::to_string(config.n_out_channels_tiles) + // # output pixel tiles per wk-item (ALU) std::string(" -DMLO_N_LCL_IN_MAPS=") + std::to_string(result.n_in_data_tiles) + @@ -686,7 +687,7 @@ ConvSolution ConvOclBwdWrW2::GetSolution( kernel.l_wk.push_back(grp_tile2); assert(total_out_maps != 0); - size_t gbl_wk1 = std::ceil(static_cast(problem.GetInChannels()) / total_out_maps); + size_t gbl_wk1 = std::ceil(static_cast(problem.GetInChannels_()) / total_out_maps); size_t gbl_wk2 = n_batch_blks; size_t gbl_wk0 = workgroup_size; @@ -698,7 +699,7 @@ ConvSolution ConvOclBwdWrW2::GetSolution( } else { - gbl_wk0 *= problem.GetOutChannels(); + gbl_wk0 *= problem.GetOutChannels_(); kernel.kernel_file = "MIOpenConvBwdWrWS2.cl"; kernel.kernel_name = "MIOpenCvBwdWrW"; } @@ -726,8 +727,7 @@ ConvSolution ConvOclBwdWrW2::GetSolution( kernel.l_wk.push_back(1); assert(utility_read_unit != 0); - int gbl_ut_wk0 = static_cast(static_cast(wei_bstride) * problem.GetInChannels() / - utility_read_unit); + unsigned gbl_ut_wk0 = wei_bstride * problem.GetInChannels_() / utility_read_unit; kernel.g_wk.push_back(gbl_ut_wk0); kernel.g_wk.push_back(1); diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp index 0cde102649..e2ae607157 100644 --- a/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp +++ b/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp @@ -68,29 +68,30 @@ bool ConvOclBwdWrW53::IsApplicable(const ConvolutionContext& ctx, { // Workaround for issue 1173. These FP16 configs would cause clang-ocl compiler to crash // during kernel compilation, due to compiler bug - workaround = workaround || (problem.GetOutDataType() == miopenHalf && - ((problem.GetWeightsWidth() == 7 && - problem.GetWeightsHeight() == 7 && problem.GetPadW() == 3) || - (problem.GetWeightsWidth() == 7 && - problem.GetWeightsHeight() == 7 && problem.GetPadW() == 2) || - (problem.GetWeightsWidth() == 11 && - problem.GetWeightsHeight() == 11 && problem.GetPadW() == 5) || - (problem.GetWeightsWidth() == 11 && - problem.GetWeightsHeight() == 11 && problem.GetPadW() == 2) || - (problem.GetWeightsWidth() == 11 && - problem.GetWeightsHeight() == 11 && problem.GetPadW() == 1))); + workaround = + workaround || (problem.GetOutDataType() == miopenHalf && + ((problem.GetWeightsWidth_() == 7 && problem.GetWeightsHeight_() == 7 && + problem.GetPadW() == 3) || + (problem.GetWeightsWidth_() == 7 && problem.GetWeightsHeight_() == 7 && + problem.GetPadW() == 2) || + (problem.GetWeightsWidth_() == 11 && + problem.GetWeightsHeight_() == 11 && problem.GetPadW() == 5) || + (problem.GetWeightsWidth_() == 11 && + problem.GetWeightsHeight_() == 11 && problem.GetPadW() == 2) || + (problem.GetWeightsWidth_() == 11 && + problem.GetWeightsHeight_() == 11 && problem.GetPadW() == 1))); // Workaround for issue 1242. These FP32 configs produce wrong result if compiled with // OpenCL 1.2.0-2018090737 that comes with rocm 1.9, using -O2 flag or higher. // However, when compiled with older OpenCL that comes with rocm 1.8, this config // would pass - workaround = - workaround || (problem.GetOutDataType() == miopenFloat && - ((problem.GetWeightsWidth() == 7 && problem.GetWeightsHeight() == 7 && - problem.GetPadW() == 3) || - (problem.GetWeightsWidth() == 7 && problem.GetWeightsHeight() == 7 && - problem.GetPadW() == 1)) && - (problem.GetOutHeight() % 112 == 0 || problem.GetOutWidth() % 112 == 0)); + workaround = workaround || + (problem.GetOutDataType() == miopenFloat && + ((problem.GetWeightsWidth_() == 7 && problem.GetWeightsHeight_() == 7 && + problem.GetPadW() == 3) || + (problem.GetWeightsWidth_() == 7 && problem.GetWeightsHeight_() == 7 && + problem.GetPadW() == 1)) && + (problem.GetOutHeight_() % 112 == 0 || problem.GetOutWidth_() % 112 == 0)); // Workaround for issue 1479 // The compiler issue causes the correctness failure of particular config @@ -98,9 +99,9 @@ bool ConvOclBwdWrW53::IsApplicable(const ConvolutionContext& ctx, // Disabling compiler optimization i.e. #pragma unroll in MIOpenConvBwdWrW_LxG_P53.cl // restores the correctness. Until, the compiler issue is fixed, all configs with width 1024 // is skipped - workaround = workaround || (problem.IsFp32() && problem.GetWeightsWidth() == 3 && - problem.GetWeightsHeight() == 3 && problem.GetPadH() == 2 && - problem.GetPadW() == 2 && problem.GetOutWidth() == 1024); + workaround = workaround || (problem.IsFp32() && problem.GetWeightsWidth_() == 3 && + problem.GetWeightsHeight_() == 3 && problem.GetPadH() == 2 && + problem.GetPadW() == 2 && problem.GetOutWidth_() == 1024); } /// Resolve NaN issue on gfx908, manifested on Jenkins. @@ -108,20 +109,22 @@ bool ConvOclBwdWrW53::IsApplicable(const ConvolutionContext& ctx, /// performance and applicable for the affected "popular" configs (7x7 filter, 1x1 padding). const auto name = ctx.GetStream().GetDeviceName(); workaround = - workaround || (problem.IsFp16() && (name == "gfx908") && problem.GetWeightsWidth() == 7 && - problem.GetWeightsHeight() == 7 && problem.GetPadW() == 1); + workaround || (problem.IsFp16() && (name == "gfx908") && problem.GetWeightsWidth_() == 7 && + problem.GetWeightsHeight_() == 7 && problem.GetPadW() == 1); return (problem.GetDilationW() == 1 && problem.GetDilationH() == 1) && (problem.GetKernelStrideW() == 1 && problem.GetKernelStrideH() == 1) && // This limitation is because of the way the kernel process data at lower vertical // boundary (including padding). - (problem.GetWeightsHeight() >= problem.GetPadH() + problem.GetKernelStrideH()) && + (static_cast(problem.GetWeightsHeight_()) >= + problem.GetPadH() + problem.GetKernelStrideH()) && // Input image height plus vertical paddings should be no less than filter vertical size. // TODO: chao: revisit this to make sure this is the actual limitation. // Remind that input is output, output is input. - (problem.GetWeightsHeight() <= problem.GetOutHeight() + 2 * problem.GetPadH()) && + (static_cast(problem.GetWeightsHeight_()) <= + static_cast(problem.GetOutHeight_()) + 2 * problem.GetPadH()) && // Input and output width and height need to match exactly, // meaning, filter's moving range should be the same as input plus padding. @@ -129,10 +132,12 @@ bool ConvOclBwdWrW53::IsApplicable(const ConvolutionContext& ctx, // right padding, when reading an input row into LDS. Also need to rewrite the vertical // loop. // Remind that input is output, output is input. - (problem.GetInHeight() == - problem.GetOutHeight() + 2 * problem.GetPadH() - problem.GetWeightsHeight() + 1) && - (problem.GetInWidth() == - problem.GetOutWidth() + 2 * problem.GetPadW() - problem.GetWeightsWidth() + 1) && + (problem.GetInHeight_() == static_cast(problem.GetOutHeight_()) + + 2 * problem.GetPadH() - + static_cast(problem.GetWeightsHeight_()) + 1) && + (problem.GetInWidth_() == static_cast(problem.GetOutWidth_()) + + 2 * problem.GetPadW() - + static_cast(problem.GetWeightsWidth_()) + 1) && // Avoid LDS over-allocation GetSolution(ctx, problem).Succeeded() && !workaround; @@ -166,7 +171,7 @@ static inline miopenStatus_t ComputeInputParams( // As each width chunk starts to get split, // it should include complete kernel filter in horizontal span. - const unsigned filter_adjustment = problem.GetWeightsWidth() - 1; + const unsigned filter_adjustment = problem.GetWeightsWidth_() - 1; const auto lds_size = 64 * 1024; /// TBD Obtain this from device info. const auto max_lds_elements = @@ -177,14 +182,14 @@ static inline miopenStatus_t ComputeInputParams( { if(out_n_vert_reads < 2 && num_out_channels >= 2) { - out_n_vert_reads = problem.GetInHeight(); + out_n_vert_reads = problem.GetInHeight_(); num_out_channels = std::ceil(static_cast(num_out_channels) / 2); } - else if(out_n_vert_reads >= problem.GetWeightsHeight() * 2) + else if(out_n_vert_reads >= problem.GetWeightsHeight_() * 2) { out_n_vert_reads = std::ceil(static_cast(out_n_vert_reads) / 2); } - else if(out_n_vert_reads >= problem.GetWeightsHeight() && out_n_horizon_reads > 2) + else if(out_n_vert_reads >= problem.GetWeightsHeight_() && out_n_horizon_reads > 2) { out_n_horizon_reads = std::ceil(static_cast(out_n_horizon_reads) / 2); } @@ -198,9 +203,9 @@ static inline miopenStatus_t ComputeInputParams( // LDS check based on weight blob // Kernel uses LDS for storing input data and weight accumulation - if(workgroup_size * problem.GetWeightsWidth() > max_lds_elements) + if(workgroup_size * problem.GetWeightsWidth_() > max_lds_elements) { - MIOPEN_LOG_I2("For large filter size " << problem.GetWeightsWidth() + MIOPEN_LOG_I2("For large filter size " << problem.GetWeightsWidth_() << ", running out of LDS size (bytes) " << lds_size); return miopenStatusNotInitialized; } @@ -312,19 +317,20 @@ static inline void ComputeNumInputWidthLoops( size_t ConvOclBwdWrW53::GetWorkspaceSize(const ConvolutionContext&, const ProblemDescription& problem) const { - int n_stacks = std::min(problem.GetBatchSize(), 1); - int N_BATCH_LOOPS = (problem.GetInChannels() * problem.GetOutChannels() <= 8 * 1024) ? 1 - : (problem.GetBatchSize() <= 16 || problem.GetInWidth() <= 32) - ? (problem.GetBatchSize() / n_stacks) + int n_stacks = std::min(problem.GetBatchSize_(), 1U); + int N_BATCH_LOOPS = (problem.GetInChannels_() * problem.GetOutChannels_() <= 8 * 1024) ? 1 + : (problem.GetBatchSize_() <= 16 || problem.GetInWidth_() <= 32) + ? (problem.GetBatchSize_() / n_stacks) : 4; int n_batch_blks = - (problem.GetBatchSize() + N_BATCH_LOOPS * n_stacks - 1) / (N_BATCH_LOOPS * n_stacks); + (problem.GetBatchSize_() + N_BATCH_LOOPS * n_stacks - 1) / (N_BATCH_LOOPS * n_stacks); if(n_batch_blks > 1) { - int wei_bstride = (problem.GetOutChannels() / problem.GetGroupCount()) * - (problem.GetWeightsWidth() * problem.GetWeightsHeight()); + int wei_bstride = (problem.GetOutChannels_() / problem.GetGroupCount()) * + (problem.GetWeightsWidth_() * problem.GetWeightsHeight_()); int data_len = GetTypeSize(problem.GetOutDataType()); - return static_cast(wei_bstride) * problem.GetInChannels() * n_batch_blks * data_len; + return static_cast(wei_bstride) * problem.GetInChannels_() * n_batch_blks * + data_len; } else return 0; @@ -337,56 +343,58 @@ ConvSolution ConvOclBwdWrW53::GetSolution(const ConvolutionContext& ctx, const auto hw_wave_sz = 64; // inpout are outputs - int wei_cstride = problem.GetWeightsWidth() * problem.GetWeightsHeight(); + int wei_cstride = problem.GetWeightsWidth_() * problem.GetWeightsHeight_(); // At convolutionocl level, the assertion is present to ensure output channels are // in multiple of group counts - int wei_bstride = (problem.GetOutChannels() / problem.GetGroupCount()) * wei_cstride; + int wei_bstride = (problem.GetOutChannels_() / problem.GetGroupCount()) * wei_cstride; // number of batch iterations - result.n_stacks = 1; - result.n_stacks = std::min(problem.GetBatchSize(), result.n_stacks); + result.n_stacks = std::min(problem.GetBatchSize_(), 1U); // defines how to proceed : 1 grouop per batch or with a loop over all batches // loop over al batches make sense in 2 cases: a lot of small inputs/outputs or few batches - int N_BATCH_LOOPS = (problem.GetInChannels() * problem.GetOutChannels() <= 8 * 1024) ? 1 - : (problem.GetBatchSize() <= 16 || problem.GetInWidth() <= 32) - ? (problem.GetBatchSize() / result.n_stacks) + int N_BATCH_LOOPS = (problem.GetInChannels_() * problem.GetOutChannels_() <= 8 * 1024) ? 1 + : (problem.GetBatchSize_() <= 16 || problem.GetInWidth_() <= 32) + ? (problem.GetBatchSize_() / result.n_stacks) : 4; - int n_batch_blks = (problem.GetBatchSize() + N_BATCH_LOOPS * result.n_stacks - 1) / + int n_batch_blks = (problem.GetBatchSize_() + N_BATCH_LOOPS * result.n_stacks - 1) / (N_BATCH_LOOPS * result.n_stacks); - result.out_pix_tile0 = problem.GetWeightsWidth(); - result.out_pix_tile1 = problem.GetWeightsHeight(); + result.out_pix_tile0 = problem.GetWeightsWidth_(); + result.out_pix_tile1 = problem.GetWeightsHeight_(); // n of wavefronts per group int n_waves = - ((result.out_pix_tile0 * result.out_pix_tile1) <= 16 && (problem.GetInWidth() > 8)) ? 4 - : (problem.GetInWidth() <= 16) ? 1 - : 2; + ((result.out_pix_tile0 * result.out_pix_tile1) <= 16 && (problem.GetInWidth_() > 8)) ? 4 + : (problem.GetInWidth_() <= 16) ? 1 + : 2; int GRP_SZ = hw_wave_sz * n_waves; result.n_in_data_tiles = - (problem.GetInWidth() <= 32 && (result.out_pix_tile0 * result.out_pix_tile1) <= 16) ? 4 : 1; + (problem.GetInWidth_() <= 32 && (result.out_pix_tile0 * result.out_pix_tile1) <= 16) ? 4 + : 1; result.n_in_data_tiles = - std::min(result.n_in_data_tiles, (problem.GetOutChannels() / problem.GetGroupCount())); + std::min(result.n_in_data_tiles, + static_cast(problem.GetOutChannels_() / problem.GetGroupCount())); - static const int read_unit = (problem.GetOutWidth() % 4 == 0) ? 4 - : (problem.GetOutWidth() % 3 == 0) ? 3 - : (problem.GetOutWidth() % 2 == 0) ? 2 - : 1; + static const int read_unit = (problem.GetOutWidth_() % 4 == 0) ? 4 + : (problem.GetOutWidth_() % 3 == 0) ? 3 + : (problem.GetOutWidth_() % 2 == 0) ? 2 + : 1; static const std::string READ_TYPE = (read_unit == 1) ? "_FLOAT" : "_FLOAT" + std::to_string((read_unit)); // calculate number of input scans in the input block int out_lcl_width = - ((problem.GetOutWidth() + read_unit - 1) / read_unit) * read_unit + 2 * problem.GetPadW(); + ((static_cast(problem.GetOutWidth_()) + read_unit - 1) / read_unit) * read_unit + + 2 * problem.GetPadW(); // number of input map blocks being process at once - int out_n_vert_reads = (problem.GetOutHeight() > 32 && problem.GetOutWidth() <= 64 && + int out_n_vert_reads = (problem.GetOutHeight_() > 32 && problem.GetOutWidth_() <= 64 && (result.out_pix_tile0 * result.out_pix_tile1) <= 16) - ? (problem.GetOutHeight() + 1) / 2 - : problem.GetOutHeight(); + ? (problem.GetOutHeight_() + 1) / 2 + : problem.GetOutHeight_(); // Given the availability of LDS, recomputes the params int out_n_horizon_reads = out_lcl_width; @@ -402,20 +410,21 @@ ConvSolution ConvOclBwdWrW53::GetSolution(const ConvolutionContext& ctx, } int out_n_vert_read_loops = static_cast(std::ceil( - static_cast(problem.GetOutHeight()) / static_cast(out_n_vert_reads))); + static_cast(problem.GetOutHeight_()) / static_cast(out_n_vert_reads))); // When a row is split into chunks, each chunk should fully cover the entire filter in // horizontal dir - out_n_horizon_reads = (out_n_horizon_reads == out_lcl_width) - ? out_lcl_width - : (out_n_horizon_reads + problem.GetWeightsWidth() - 1); + out_n_horizon_reads = + (out_n_horizon_reads == out_lcl_width) + ? out_lcl_width + : (out_n_horizon_reads + static_cast(problem.GetWeightsWidth_()) - 1); int out_n_horizon_read_loops = 1; int out_horizon_last_chunk_valid_pixels = 0; ComputeNumInputWidthLoops(out_lcl_width, problem.GetPadW(), out_n_horizon_reads, - problem.GetWeightsWidth(), + problem.GetWeightsWidth_(), out_n_horizon_read_loops, out_horizon_last_chunk_valid_pixels); if(out_n_horizon_read_loops > 2 && problem.GetPadW() != 0) @@ -437,20 +446,20 @@ ConvSolution ConvOclBwdWrW53::GetSolution(const ConvolutionContext& ctx, : read_unit; // Compute in -> out in kernel i.e. dy - int in_width_chunk = - (out_n_horizon_read_loops == 1) - ? problem.GetInWidth() - : (out_n_horizon_reads + problem.GetPadW() - problem.GetWeightsWidth() + 1); + int in_width_chunk = (out_n_horizon_read_loops == 1) + ? problem.GetInWidth_() + : (out_n_horizon_reads + problem.GetPadW() - + static_cast(problem.GetWeightsWidth_()) + 1); int in_width_last_chunk_valid_pixels = - (out_n_horizon_read_loops == 1) ? 0 : (problem.GetInWidth() % in_width_chunk); + (out_n_horizon_read_loops == 1) ? 0 : (problem.GetInWidth_() % in_width_chunk); result.in_tile1 = 1; result.n_out_pix_tiles = 1; int n_out_stacks = 1; - ComputeOutputParams(problem.GetInWidth(), + ComputeOutputParams(problem.GetInWidth_(), GRP_SZ, in_width_chunk, - problem.GetInChannels(), + problem.GetInChannels_(), problem.GetGroupCount(), result.in_tile0, n_out_stacks); @@ -470,7 +479,7 @@ ConvSolution ConvOclBwdWrW53::GetSolution(const ConvolutionContext& ctx, // select output mapping int total_out_maps = result.n_out_pix_tiles * n_out_stacks; total_out_maps = - (total_out_maps > problem.GetInChannels()) ? problem.GetInChannels() : total_out_maps; + (total_out_maps > problem.GetInChannels_()) ? problem.GetInChannels_() : total_out_maps; result.grp_tile0 = GRP_SZ; result.grp_tile1 = 1; @@ -486,8 +495,8 @@ ConvSolution ConvOclBwdWrW53::GetSolution(const ConvolutionContext& ctx, (ut_read_unit == 1) ? "_FLOAT" : "_FLOAT" + std::to_string((ut_read_unit)); // group parameters - int n_input_channels_per_group = problem.GetOutChannels() / problem.GetGroupCount(); - int n_output_channels_per_group = problem.GetInChannels() / problem.GetGroupCount(); + int n_input_channels_per_group = problem.GetOutChannels_() / problem.GetGroupCount(); + int n_output_channels_per_group = problem.GetInChannels_() / problem.GetGroupCount(); if(!problem.direction.IsBackwardWrW()) MIOPEN_THROW("!problem.direction.IsBackwardWrW()"); @@ -497,33 +506,33 @@ ConvSolution ConvOclBwdWrW53::GetSolution(const ConvolutionContext& ctx, std::to_string(GRP_SZ) + std::string(" -DMLO_GRP_SZ0=") + std::to_string(result.grp_tile0) + std::string(" -DMLO_GRP_SZ1=") + std::to_string(result.grp_tile1) + std::string(" -DMLO_GRP_SZ2=") + std::to_string(grp_tile2) + - std::string(" -DMLO_FILTER_SIZE0=") + std::to_string(problem.GetWeightsWidth()) + - std::string(" -DMLO_FILTER_SIZE1=") + std::to_string(problem.GetWeightsHeight()) + + std::string(" -DMLO_FILTER_SIZE0=") + std::to_string(problem.GetWeightsWidth_()) + + std::string(" -DMLO_FILTER_SIZE1=") + std::to_string(problem.GetWeightsHeight_()) + std::string(" -DMLO_FILTER_PAD0=") + std::to_string(problem.GetPadW()) + std::string(" -DMLO_FILTER_PAD1=") + std::to_string(problem.GetPadH()) + std::string(" -DMLO_FILTER_STRIDE0=") + std::to_string(problem.GetKernelStrideW()) + std::string(" -DMLO_FILTER_STRIDE1=") + std::to_string(problem.GetKernelStrideH()) + std::string(" -DSTRIDE_W=") + std::to_string(problem.GetKernelStrideW()) + std::string(" -DSTRIDE_H=") + std::to_string(problem.GetKernelStrideH()) + - std::string(" -DMLO_N_OUTPUTS=") + std::to_string(problem.GetInChannels()) + - std::string(" -DMLO_N_INPUTS=") + std::to_string(problem.GetOutChannels()) + + std::string(" -DMLO_N_OUTPUTS=") + std::to_string(problem.GetInChannels_()) + + std::string(" -DMLO_N_INPUTS=") + std::to_string(problem.GetOutChannels_()) + std::string(" -DMLO_GROUP_COUNTS=") + std::to_string(problem.GetGroupCount()) + std::string(" -DMLO_N_INPUTS_PER_GROUP=") + std::to_string(n_input_channels_per_group) + std::string(" -DMLO_N_OUTPUTS_PER_GROUP=") + std::to_string(n_output_channels_per_group) + - std::string(" -DMLO_BATCH_SZ=") + std::to_string(problem.GetBatchSize()) + + std::string(" -DMLO_BATCH_SZ=") + std::to_string(problem.GetBatchSize_()) + std::string(" -DMLO_N_BATCH_LOOPS=") + std::to_string(N_BATCH_LOOPS) + - std::string(" -DMLO_OUT_BATCH_STRIDE=") + std::to_string(problem.GetInBatchStride()) + - std::string(" -DMLO_OUT_CHANNEL_STRIDE=") + std::to_string(problem.GetInChannelStride()) + - std::string(" -DMLO_OUT_STRIDE=") + std::to_string(problem.GetInStride()) + - std::string(" -DMLO_IN_BATCH_STRIDE=") + std::to_string(problem.GetOutBatchStride()) + - std::string(" -DMLO_IN_CHANNEL_STRIDE=") + std::to_string(problem.GetOutChannelStride()) + - std::string(" -DMLO_IN_STRIDE=") + std::to_string(problem.GetOutStride()) + + std::string(" -DMLO_OUT_BATCH_STRIDE=") + std::to_string(problem.GetInBatchStride_()) + + std::string(" -DMLO_OUT_CHANNEL_STRIDE=") + std::to_string(problem.GetInChannelStride_()) + + std::string(" -DMLO_OUT_STRIDE=") + std::to_string(problem.GetInStrideH_()) + + std::string(" -DMLO_IN_BATCH_STRIDE=") + std::to_string(problem.GetOutBatchStride_()) + + std::string(" -DMLO_IN_CHANNEL_STRIDE=") + std::to_string(problem.GetOutChannelStride_()) + + std::string(" -DMLO_IN_STRIDE=") + std::to_string(problem.GetOutStrideH_()) + std::string(" -DMLO_WEI_BATCH_STRIDE=") + std::to_string(wei_bstride) + std::string(" -DMLO_WEI_CHANNEL_STRIDE=") + std::to_string(wei_cstride) + - std::string(" -DMLO_IN_WIDTH=") + std::to_string(problem.GetOutWidth()) + - std::string(" -DMLO_IN_HEIGHT=") + std::to_string(problem.GetOutHeight()) + - std::string(" -DMLO_OUT_WIDTH=") + std::to_string(problem.GetInWidth()) + - std::string(" -DMLO_OUT_HEIGHT=") + std::to_string(problem.GetInHeight()) + + std::string(" -DMLO_IN_WIDTH=") + std::to_string(problem.GetOutWidth_()) + + std::string(" -DMLO_IN_HEIGHT=") + std::to_string(problem.GetOutHeight_()) + + std::string(" -DMLO_OUT_WIDTH=") + std::to_string(problem.GetInWidth_()) + + std::string(" -DMLO_OUT_HEIGHT=") + std::to_string(problem.GetInHeight_()) + std::string(" -DMLO_IN_TILE1=") + std::to_string(result.in_tile1) + std::string(" -DMLO_IN_TILE0=") + std::to_string(result.in_tile0) + std::string(" -DMLO_N_LCL_BATCHS=") + @@ -543,7 +552,7 @@ ConvSolution ConvOclBwdWrW53::GetSolution(const ConvolutionContext& ctx, std::string(" -DMLO_IN_EXTENT1=") + std::to_string(out_n_vert_reads) + std::string(" -DMLO_IN_N_VERT_LOOPS=") + std::to_string(out_n_vert_read_loops) + std::string(" -DMLO_IN_WIDTH_CHUNK=") + - std::to_string((out_n_horizon_read_loops == 1) ? problem.GetOutWidth() + std::to_string((out_n_horizon_read_loops == 1) ? problem.GetOutWidth_() : out_n_horizon_reads) + std::string(" -DMLO_IN_WIDTH_N_LOOPS=") + std::to_string(out_n_horizon_read_loops) + std::string(" -DMLO_IN_WIDTH_LAST_CHUNK_VALID_READ_UNITS=") + @@ -581,13 +590,13 @@ ConvSolution ConvOclBwdWrW53::GetSolution(const ConvolutionContext& ctx, kernel.l_wk.push_back(grp_tile2); // input is output - size_t gbl_wk1 = ((problem.GetInChannels() + total_out_maps - 1) / total_out_maps); + size_t gbl_wk1 = ((problem.GetInChannels_() + total_out_maps - 1) / total_out_maps); size_t gbl_wk2 = n_batch_blks; size_t gbl_wk0 = GRP_SZ; if(problem.GetGroupCount() > 1) { - gbl_wk0 *= (((problem.GetOutChannels() / problem.GetGroupCount()) + + gbl_wk0 *= (((problem.GetOutChannels_() / problem.GetGroupCount()) + result.n_in_data_tiles - 1) / result.n_in_data_tiles); @@ -597,7 +606,7 @@ ConvSolution ConvOclBwdWrW53::GetSolution(const ConvolutionContext& ctx, else { gbl_wk0 *= - ((problem.GetOutChannels() + result.n_in_data_tiles - 1) / result.n_in_data_tiles); + ((problem.GetOutChannels_() + result.n_in_data_tiles - 1) / result.n_in_data_tiles); kernel.kernel_file = "MIOpenConvBwdWrW_LxG_P53.cl"; kernel.kernel_name = "MIOpenCvBwdWrW"; @@ -624,7 +633,7 @@ ConvSolution ConvOclBwdWrW53::GetSolution(const ConvolutionContext& ctx, kernel.l_wk.push_back(1); kernel.l_wk.push_back(1); - int gbl_ut_wk0 = wei_bstride * problem.GetInChannels() / ut_read_unit; + int gbl_ut_wk0 = wei_bstride * problem.GetInChannels_() / ut_read_unit; kernel.g_wk.push_back(gbl_ut_wk0); kernel.g_wk.push_back(1); diff --git a/src/solver/conv_ocl_dir2Dfwd.cpp b/src/solver/conv_ocl_dir2Dfwd.cpp index d388d1131a..18086410da 100644 --- a/src/solver/conv_ocl_dir2Dfwd.cpp +++ b/src/solver/conv_ocl_dir2Dfwd.cpp @@ -68,14 +68,14 @@ bool ConvOclDirectFwd::IsApplicable(const ConvolutionContext& ctx, { const auto& p = problem; //alias const bool supported = - ((p.GetWeightsHeight() == p.GetWeightsWidth()) - && (p.GetWeightsHeight() == 3 - || p.GetWeightsHeight() == 5 - || p.GetWeightsHeight() == 7 - || p.GetWeightsHeight() == 9 - || p.GetWeightsHeight() == 11)) - || ((p.GetWeightsWidth() == 10 || p.GetWeightsWidth() == 20) - && p.GetWeightsHeight() == 5 + ((p.GetWeightsHeight_() == p.GetWeightsWidth_()) + && (p.GetWeightsHeight_() == 3 + || p.GetWeightsHeight_() == 5 + || p.GetWeightsHeight_() == 7 + || p.GetWeightsHeight_() == 9 + || p.GetWeightsHeight_() == 11)) + || ((p.GetWeightsWidth_() == 10 || p.GetWeightsWidth_() == 20) + && p.GetWeightsHeight_() == 5 && p.GetKernelStrideH() == 2 && p.GetKernelStrideW() == 2 && p.GetPadH() == 0 @@ -83,8 +83,8 @@ bool ConvOclDirectFwd::IsApplicable(const ConvolutionContext& ctx, /// The following is for #1594. Most likely we can open more configs, /// but that would require thorough testing. || (p.IsFp16() - && p.GetWeightsHeight() == 4 - && p.GetWeightsWidth() == 4 + && p.GetWeightsHeight_() == 4 + && p.GetWeightsWidth_() == 4 && p.GetPadH() == 0 && p.GetPadW() == 0); @@ -98,7 +98,7 @@ bool ConvOclDirectFwd::IsApplicable(const ConvolutionContext& ctx, /// \todo need to make sure support stride > 2, should support but not tested && !(problem.GetKernelStrideW() > 2 || problem.GetKernelStrideH() > 2) /// We have optimized 1x1 kernel for normal conv. - && !(problem.GetGroupCount() == 1 && problem.GetWeightsHeight() == 1 && problem.GetWeightsWidth() == 1) + && !(problem.GetGroupCount() == 1 && problem.GetWeightsHeight_() == 1 && problem.GetWeightsWidth_() == 1) /// \todo Workaround to avoid FP16 precision issue: /// While MIOpenConvUni is up to 4x faster than MIOpenCDFGen (even not auto-tuned), /// it seems that is has 4x..20x worse precision, and some "test_conv --half" tests fail. @@ -128,14 +128,14 @@ bool ConvOclDirectFwd::IsValidPerformanceConfig(const ConvolutionContext&, // if(!problem.direction.IsForward()) // { // // backward - // pad_w = problem.GetWeightsWidth() - 1 - pad_w; - // pad_h = problem.GetWeightsHeight() - 1 - pad_h; + // pad_w = problem.GetBackwardPadW(); + // pad_h = problem.GetBackwardPadH(); // } auto group_counts = problem.GetGroupCount(); result.n_in_data_tiles = - std::min(problem.GetInChannels() / group_counts, config.n_in_data_tiles); - result.n_out_pix_tiles = - std::min(problem.GetOutChannels() / group_counts, config.n_out_pix_tiles); + std::min(static_cast(problem.GetInChannels_()) / group_counts, config.n_in_data_tiles); + result.n_out_pix_tiles = std::min(static_cast(problem.GetOutChannels_()) / group_counts, + config.n_out_pix_tiles); // hacky fix of the incorrect kernel local memory address calculation for data result.out_pix_tile1 = (!problem.direction.IsForward() && problem.GetKernelStrideH() > 1) @@ -166,7 +166,7 @@ bool ConvOclDirectFwd::IsValidPerformanceConfig(const ConvolutionContext&, int n_alus_total = (result.grp_tile0 * result.grp_tile1); result.n_stacks = std::min(result.n_stacks, (n_alus_total + alu_tiles_sz - 1) / alu_tiles_sz); - result.n_stacks = std::min(problem.GetBatchSize(), result.n_stacks); + result.n_stacks = std::min(static_cast(problem.GetBatchSize_()), result.n_stacks); if(result.n_stacks == 0 /* DIV/0 */) { @@ -190,32 +190,33 @@ bool ConvOclDirectFwd::IsValidPerformanceConfig(const ConvolutionContext&, // : (result.grp_tile1 * result.grp_tile0); // } - // int n_out_tile_blocks0 = (problem.GetOutWidth() + result.in_tile0 - 1) / (result.in_tile0); - // int n_out_tile_blocks1 = (problem.GetOutHeight() + result.in_tile1 - 1) / (result.in_tile1); + // int n_out_tile_blocks0 = (problem.GetOutWidth_() + result.in_tile0 - 1) / (result.in_tile0); + // int n_out_tile_blocks1 = (problem.GetOutHeight_() + result.in_tile1 - 1) / (result.in_tile1); int n_alu_tiles_perstack = (n_alus_perstack + alu_tiles_sz - 1) / alu_tiles_sz; int n_out_tiles_perstack = n_alu_tiles_perstack * result.n_out_pix_tiles; - n_out_tiles_perstack = std::min(n_out_tiles_perstack, problem.GetOutChannels() / group_counts); + n_out_tiles_perstack = + std::min(n_out_tiles_perstack, static_cast(problem.GetOutChannels_()) / group_counts); // const auto mlo_hw_wave_sz=hw_wave_sz; - const auto mlo_filter_size0 = static_cast(problem.GetWeightsWidth()); - const auto mlo_filter_size1 = static_cast(problem.GetWeightsHeight()); + const auto mlo_filter_size0 = static_cast(problem.GetWeightsWidth_()); + const auto mlo_filter_size1 = static_cast(problem.GetWeightsHeight_()); // const auto mlo_filter_pad0=static_cast(pad_w); // const auto mlo_filter_pad1=static_cast(pad_h); const auto mlo_filter_stride0 = static_cast(problem.GetKernelStrideW()); const auto mlo_filter_stride1 = static_cast(problem.GetKernelStrideH()); - // const auto mlo_n_outputs=static_cast(problem.GetOutChannels()); - // const auto mlo_n_inputs=static_cast(problem.GetInChannels()); - // const auto mlo_batch_sz=static_cast(problem.GetBatchSize()); - // const auto mlo_out_width=static_cast(problem.GetOutWidth()); - // const auto mlo_out_height=static_cast(problem.GetOutHeight()); - // const auto mlo_out_batch_stride=static_cast(problem.GetOutBatchStride()); - // const auto mlo_out_channel_stride=static_cast(problem.GetOutChannelStride()); - // const auto mlo_out_stride=static_cast(problem.GetOutStride()); - // const auto mlo_in_width=static_cast(problem.GetInWidth()); - // const auto mlo_in_height=static_cast(problem.GetInHeight()); - // const auto mlo_in_batch_stride=static_cast(problem.GetInBatchStride()); - // const auto mlo_in_channel_stride=static_cast(problem.GetInChannelStride()); - // const auto mlo_in_stride=static_cast(problem.GetInStride()); + // const auto mlo_n_outputs=static_cast(problem.GetOutChannels_()); + // const auto mlo_n_inputs=static_cast(problem.GetInChannels_()); + // const auto mlo_batch_sz=static_cast(problem.GetBatchSize_()); + // const auto mlo_out_width=static_cast(problem.GetOutWidth_()); + // const auto mlo_out_height=static_cast(problem.GetOutHeight_()); + // const auto mlo_out_batch_stride=static_cast(problem.GetOutBatchStride_()); + // const auto mlo_out_channel_stride=static_cast(problem.GetOutChannelStride_()); + // const auto mlo_out_stride=static_cast(problem.GetOutStrideH_()); + // const auto mlo_in_width=static_cast(problem.GetInWidth_()); + // const auto mlo_in_height=static_cast(problem.GetInHeight_()); + // const auto mlo_in_batch_stride=static_cast(problem.GetInBatchStride_()); + // const auto mlo_in_channel_stride=static_cast(problem.GetInChannelStride_()); + // const auto mlo_in_stride=static_cast(problem.GetInStrideH_()); // algorithm parameters const auto mlo_in_tile0 = static_cast(result.in_tile0); const auto mlo_in_tile1 = static_cast(result.in_tile1); @@ -290,14 +291,14 @@ ConvSolution ConvOclDirectFwd::BaseGetSolution(const ConvolutionContext& ctx, if(!problem.direction.IsForward()) { // backward - pad_w = problem.GetWeightsWidth() - 1 - pad_w; - pad_h = problem.GetWeightsHeight() - 1 - pad_h; + pad_w = problem.GetBackwardPadW(); + pad_h = problem.GetBackwardPadH(); } result.n_in_data_tiles = - std::min(problem.GetInChannels() / group_counts, config.n_in_data_tiles); - result.n_out_pix_tiles = - std::min(problem.GetOutChannels() / group_counts, config.n_out_pix_tiles); + std::min(static_cast(problem.GetInChannels_()) / group_counts, config.n_in_data_tiles); + result.n_out_pix_tiles = std::min(static_cast(problem.GetOutChannels_()) / group_counts, + config.n_out_pix_tiles); // hacky fix of the incorrect kernel local memory address calculation for data result.out_pix_tile1 = (!problem.direction.IsForward() && problem.GetKernelStrideH() > 1) @@ -329,7 +330,7 @@ ConvSolution ConvOclDirectFwd::BaseGetSolution(const ConvolutionContext& ctx, int n_alus_total = (result.grp_tile0 * result.grp_tile1); result.n_stacks = std::min(result.n_stacks, (n_alus_total + alu_tiles_sz - 1) / alu_tiles_sz); - result.n_stacks = std::min(problem.GetBatchSize(), result.n_stacks); + result.n_stacks = std::min(static_cast(problem.GetBatchSize_()), result.n_stacks); if(result.n_stacks == 0 /* DIV/0 */) { @@ -352,13 +353,14 @@ ConvSolution ConvOclDirectFwd::BaseGetSolution(const ConvolutionContext& ctx, : (result.grp_tile1 * result.grp_tile0); } - int n_out_tile_blocks0 = (problem.GetOutWidth() + result.in_tile0 - 1) / (result.in_tile0); - int n_out_tile_blocks1 = (problem.GetOutHeight() + result.in_tile1 - 1) / (result.in_tile1); + int n_out_tile_blocks0 = (problem.GetOutWidth_() + result.in_tile0 - 1) / (result.in_tile0); + int n_out_tile_blocks1 = (problem.GetOutHeight_() + result.in_tile1 - 1) / (result.in_tile1); int n_alu_tiles_perstack = (n_alus_perstack + alu_tiles_sz - 1) / alu_tiles_sz; int n_out_tiles_perstack = n_alu_tiles_perstack * result.n_out_pix_tiles; - n_out_tiles_perstack = std::min(n_out_tiles_perstack, problem.GetOutChannels() / group_counts); + n_out_tiles_perstack = + std::min(n_out_tiles_perstack, static_cast(problem.GetOutChannels_()) / group_counts); KernelInfo kernel_params; @@ -366,9 +368,9 @@ ConvSolution ConvOclDirectFwd::BaseGetSolution(const ConvolutionContext& ctx, std::string(" -DMLO_HW_WAVE_SZ=") + std::to_string(static_cast(hw_wave_sz)) + std::string(" -DMLO_DIR_FORWARD=") + (problem.direction.IsForward() ? "1" : "0") + std::string(" -DMLO_FILTER_SIZE0=") + - std::to_string(static_cast(problem.GetWeightsWidth())) + + std::to_string(static_cast(problem.GetWeightsWidth_())) + std::string(" -DMLO_FILTER_SIZE1=") + - std::to_string(static_cast(problem.GetWeightsHeight())) + + std::to_string(static_cast(problem.GetWeightsHeight_())) + std::string(" -DMLO_FILTER_PAD0=") + std::to_string(static_cast(pad_w)) + std::string(" -DMLO_FILTER_PAD1=") + std::to_string(static_cast(pad_h)) + std::string(" -DMLO_FILTER_STRIDE0=") + @@ -376,31 +378,31 @@ ConvSolution ConvOclDirectFwd::BaseGetSolution(const ConvolutionContext& ctx, std::string(" -DMLO_FILTER_STRIDE1=") + std::to_string(static_cast(problem.GetKernelStrideH())) + std::string(" -DMLO_N_OUTPUTS=") + - std::to_string(static_cast(problem.GetOutChannels())) + + std::to_string(static_cast(problem.GetOutChannels_())) + std::string(" -DMLO_N_INPUTS=") + - std::to_string(static_cast(problem.GetInChannels())) + + std::to_string(static_cast(problem.GetInChannels_())) + std::string(" -DMLO_BATCH_SZ=") + - std::to_string(static_cast(problem.GetBatchSize())) + + std::to_string(static_cast(problem.GetBatchSize_())) + std::string(" -DMLO_OUT_WIDTH=") + - std::to_string(static_cast(problem.GetOutWidth())) + + std::to_string(static_cast(problem.GetOutWidth_())) + std::string(" -DMLO_OUT_HEIGHT=") + - std::to_string(static_cast(problem.GetOutHeight())) + + std::to_string(static_cast(problem.GetOutHeight_())) + std::string(" -DMLO_OUT_BATCH_STRIDE=") + - std::to_string(static_cast(problem.GetOutBatchStride())) + + std::to_string(static_cast(problem.GetOutBatchStride_())) + std::string(" -DMLO_OUT_CHANNEL_STRIDE=") + - std::to_string(static_cast(problem.GetOutChannelStride())) + + std::to_string(static_cast(problem.GetOutChannelStride_())) + std::string(" -DMLO_OUT_STRIDE=") + - std::to_string(static_cast(problem.GetOutStride())) + + std::to_string(static_cast(problem.GetOutStrideH_())) + std::string(" -DMLO_IN_WIDTH=") + - std::to_string(static_cast(problem.GetInWidth())) + + std::to_string(static_cast(problem.GetInWidth_())) + std::string(" -DMLO_IN_HEIGHT=") + - std::to_string(static_cast(problem.GetInHeight())) + + std::to_string(static_cast(problem.GetInHeight_())) + std::string(" -DMLO_IN_BATCH_STRIDE=") + - std::to_string(static_cast(problem.GetInBatchStride())) + + std::to_string(static_cast(problem.GetInBatchStride_())) + std::string(" -DMLO_IN_CHANNEL_STRIDE=") + - std::to_string(static_cast(problem.GetInChannelStride())) + + std::to_string(static_cast(problem.GetInChannelStride_())) + std::string(" -DMLO_IN_STRIDE=") + - std::to_string(static_cast(problem.GetInStride())) + std::to_string(static_cast(problem.GetInStrideH_())) // algorithm parameters + std::string(" -DMLO_IN_TILE0=") + std::to_string(static_cast(result.in_tile0)) // size of input data per ALU plane @@ -436,11 +438,11 @@ ConvSolution ConvOclDirectFwd::BaseGetSolution(const ConvolutionContext& ctx, std::to_string(static_cast(group_counts))); kernel_params.comp_options += (std::string(" -DMLO_GROUP_TILES=") + - std::to_string(static_cast(problem.GetOutChannels() / group_counts))); + std::to_string(static_cast(problem.GetOutChannels_() / group_counts))); kernel_params.comp_options += (std::string(" -DMLO_STACK_PERGROUP=") + std::to_string(static_cast( - (problem.GetOutChannels() / group_counts + n_out_tiles_perstack - 1) / + (problem.GetOutChannels_() / group_counts + n_out_tiles_perstack - 1) / n_out_tiles_perstack))); kernel_params.comp_options += std::string(" -DGRP_MOD_ENABLE"); } @@ -458,11 +460,11 @@ ConvSolution ConvOclDirectFwd::BaseGetSolution(const ConvolutionContext& ctx, } size_t gbl_wk1 = group_counts >= 2 - ? (((problem.GetOutChannels() / group_counts + n_out_tiles_perstack - 1) / + ? (((problem.GetOutChannels_() / group_counts + n_out_tiles_perstack - 1) / n_out_tiles_perstack) * group_counts) - : ((problem.GetOutChannels() + n_out_tiles_perstack - 1) / n_out_tiles_perstack); - size_t gbl_wk2 = (problem.GetBatchSize() + result.n_stacks - 1) / result.n_stacks; + : ((problem.GetOutChannels_() + n_out_tiles_perstack - 1) / n_out_tiles_perstack); + size_t gbl_wk2 = (problem.GetBatchSize_() + result.n_stacks - 1) / result.n_stacks; kernel_params.g_wk.push_back(gbl_wk0 * kernel_params.l_wk[0]); kernel_params.g_wk.push_back(gbl_wk1); diff --git a/src/solver/conv_ocl_dir2Dfwd1x1.cpp b/src/solver/conv_ocl_dir2Dfwd1x1.cpp index d11eaddc23..9c6f392821 100644 --- a/src/solver/conv_ocl_dir2Dfwd1x1.cpp +++ b/src/solver/conv_ocl_dir2Dfwd1x1.cpp @@ -67,7 +67,7 @@ bool ConvOclDirectFwd1x1::IsApplicable(const ConvolutionContext& ctx, } return problem.GetDilationW() == 1 && problem.GetDilationH() == 1 && - problem.GetWeightsWidth() == 1 && problem.GetWeightsHeight() == 1 && + problem.GetWeightsWidth_() == 1 && problem.GetWeightsHeight_() == 1 && problem.GetGroupCount() == 1 && // TODO: update 1x1 fwd kernel to support padding problem.GetPadW() == 0 && problem.GetPadH() == 0; @@ -80,12 +80,12 @@ ConvSolution ConvOclDirectFwd1x1::GetSolution(const ConvolutionContext& ctx, ConvSolution result; config.CopyTo(result); - // if(problem.GetOutChannels() % 4 == 0 && problem.GetInChannels() % 4 == 0) + // if(problem.GetOutChannels_() % 4 == 0 && problem.GetInChannels_() % 4 == 0) { // int version = result.out_pix_tile1; - if((problem.direction.IsForward() && problem.GetInChannels() % 16 == 0 && - problem.GetOutChannels() % 16 == 0) && + if((problem.direction.IsForward() && problem.GetInChannels_() % 16 == 0 && + problem.GetOutChannels_() % 16 == 0) && (problem.GetInDataType() == miopenFloat)) { @@ -95,13 +95,13 @@ ConvSolution ConvOclDirectFwd1x1::GetSolution(const ConvolutionContext& ctx, // 0 or 1 int CHEAT_SHADER_COMPILER = result.out_pix_tile0; - int BATCHSIZE = problem.GetBatchSize(); - int W = problem.GetInWidth(); - int H = problem.GetInHeight(); - int C = problem.GetInChannels(); - int K = problem.GetOutChannels(); - int W_out = problem.GetOutWidth(); - int H_out = problem.GetOutHeight(); + int BATCHSIZE = problem.GetBatchSize_(); + int W = problem.GetInWidth_(); + int H = problem.GetInHeight_(); + int C = problem.GetInChannels_(); + int K = problem.GetOutChannels_(); + int W_out = problem.GetOutWidth_(); + int H_out = problem.GetOutHeight_(); N_LCL_IN_MAPS = std::min(N_LCL_IN_MAPS, C); N_LCL_OUT_MAPS = std::min(N_LCL_OUT_MAPS, K); @@ -198,8 +198,8 @@ ConvSolution ConvOclDirectFwd1x1::GetSolution(const ConvolutionContext& ctx, kernel.l_wk.push_back(local_wk1); kernel.l_wk.push_back(1); - size_t imagesizeAlign = ((static_cast(problem.GetOutWidth()) * - problem.GetOutHeight() * problem.GetBatchSize() + + size_t imagesizeAlign = ((static_cast(problem.GetOutWidth_()) * + problem.GetOutHeight_() * problem.GetBatchSize_() + FIXED_WORKGROUP_SIZE - 1) / FIXED_WORKGROUP_SIZE) * FIXED_WORKGROUP_SIZE; @@ -224,8 +224,8 @@ ConvSolution ConvOclDirectFwd1x1::GetSolution(const ConvolutionContext& ctx, kernel.l_wk.push_back(1); kernel.l_wk.push_back(1); - size_t imagesizeAlign = ((static_cast(problem.GetInWidth()) * - problem.GetInHeight() * problem.GetBatchSize() + + size_t imagesizeAlign = ((static_cast(problem.GetInWidth_()) * + problem.GetInHeight_() * problem.GetBatchSize_() + FIXED_WORKGROUP_SIZE - 1) / FIXED_WORKGROUP_SIZE) * FIXED_WORKGROUP_SIZE; @@ -251,46 +251,49 @@ ConvSolution ConvOclDirectFwd1x1::GetSolution(const ConvolutionContext& ctx, { // parameters - // int i_sz = problem.GetInWidth() * problem.GetInHeight(); + // int i_sz = problem.GetInWidth_() * problem.GetInHeight_(); // _out_pix_tile0 = (i_sz & 1) ? 1 : 2; - result.out_pix_tile0 = std::min(problem.GetOutWidth(), result.out_pix_tile0); - result.out_pix_tile1 = std::min(problem.GetOutHeight(), result.out_pix_tile1); + result.out_pix_tile0 = + std::min(static_cast(problem.GetOutWidth_()), result.out_pix_tile0); + result.out_pix_tile1 = + std::min(static_cast(problem.GetOutHeight_()), result.out_pix_tile1); if(!problem.direction.IsForward()) { - while(problem.GetOutWidth() % result.out_pix_tile0 != 0 && result.out_pix_tile0 > 1) + while(problem.GetOutWidth_() % result.out_pix_tile0 != 0 && + result.out_pix_tile0 > 1) { result.out_pix_tile0 /= 2; } } int read_unit = result.out_pix_tile0; - while(problem.GetInWidth() % read_unit != 0 && read_unit > 1) + while(problem.GetInWidth_() % read_unit != 0 && read_unit > 1) { read_unit /= 2; } - // problem.GetOutWidth() + // problem.GetOutWidth_() // _n_out_pix_tiles = 16; // _n_in_data_tiles = 4; // _grp_tile0 = 64; - int wei_cstride = problem.GetWeightsWidth() * problem.GetWeightsHeight(); + int wei_cstride = problem.GetWeightsWidth_() * problem.GetWeightsHeight_(); // backward: inputs are forward outputs const bool is_forward = problem.direction.IsForward(); int wei_bstride = - (is_forward ? problem.GetInChannels() : problem.GetOutChannels()) * wei_cstride; + (is_forward ? problem.GetInChannels_() : problem.GetOutChannels_()) * wei_cstride; - int OUT_WIDTH4 = problem.GetOutWidth(); - int MAP_SZ4 = (OUT_WIDTH4 * problem.GetOutHeight() + read_unit - 1) / (read_unit); + int OUT_WIDTH4 = problem.GetOutWidth_(); + int MAP_SZ4 = (OUT_WIDTH4 * problem.GetOutHeight_() + read_unit - 1) / (read_unit); // stride > 1 and/or apdding if(problem.GetPadW() > 0 || problem.GetKernelStrideW() > 1 || problem.GetPadH() > 0 || problem.GetKernelStrideH() > 1) { int step = is_forward ? read_unit : read_unit * problem.GetKernelStrideW(); - OUT_WIDTH4 = (problem.GetOutWidth() + step - 1) / (step); + OUT_WIDTH4 = (problem.GetOutWidth_() + step - 1) / (step); int OUT_HEIGHT4 = is_forward - ? problem.GetOutHeight() - : (problem.GetOutHeight() + problem.GetKernelStrideH() - 1) / + ? problem.GetOutHeight_() + : (problem.GetOutHeight_() + problem.GetKernelStrideH() - 1) / problem.GetKernelStrideH(); MAP_SZ4 = (OUT_WIDTH4 * OUT_HEIGHT4); } @@ -300,11 +303,11 @@ ConvSolution ConvOclDirectFwd1x1::GetSolution(const ConvolutionContext& ctx, if(!is_forward) { VERT_ALIGNED = - (problem.GetOutHeight() / problem.GetKernelStrideH() == problem.GetInHeight()) + (problem.GetOutHeight_() / problem.GetKernelStrideH() == problem.GetInHeight_()) ? 1 : 0; HORIZ_ALIGNED = - (problem.GetOutWidth() / problem.GetKernelStrideW() == problem.GetInWidth()) + (problem.GetOutWidth_() / problem.GetKernelStrideW() == problem.GetInWidth_()) ? 1 : 0; } @@ -312,19 +315,21 @@ ConvSolution ConvOclDirectFwd1x1::GetSolution(const ConvolutionContext& ctx, int GRP_SZ = result.grp_tile0; // number of inputs inside wk-items - result.n_in_data_tiles = std::min(problem.GetInChannels(), result.n_in_data_tiles); - while(problem.GetInChannels() % result.n_in_data_tiles != 0 && + result.n_in_data_tiles = + std::min(static_cast(problem.GetInChannels_()), result.n_in_data_tiles); + while(problem.GetInChannels_() % result.n_in_data_tiles != 0 && result.n_in_data_tiles > 1) { result.n_in_data_tiles /= 2; } int CLOOP0 = - (problem.GetInChannels() + result.n_in_data_tiles - 1) / result.n_in_data_tiles; + (problem.GetInChannels_() + result.n_in_data_tiles - 1) / result.n_in_data_tiles; // number of outputs inside wk_item - result.n_out_pix_tiles = std::min(problem.GetOutChannels(), result.n_out_pix_tiles); - while(problem.GetOutChannels() % result.n_out_pix_tiles != 0 && + result.n_out_pix_tiles = + std::min(static_cast(problem.GetOutChannels_()), result.n_out_pix_tiles); + while(problem.GetOutChannels_() % result.n_out_pix_tiles != 0 && result.n_out_pix_tiles > 1) { result.n_out_pix_tiles /= 2; @@ -334,28 +339,28 @@ ConvSolution ConvOclDirectFwd1x1::GetSolution(const ConvolutionContext& ctx, kernel.comp_options = std::string(" -DMLO_DIR_FORWARD=") + (is_forward ? "1" : "0") + - std::string(" -DMLO_FILTER_SIZE0=") + std::to_string(problem.GetWeightsWidth()) + - std::string(" -DMLO_FILTER_SIZE1=") + std::to_string(problem.GetWeightsHeight()) + + std::string(" -DMLO_FILTER_SIZE0=") + std::to_string(problem.GetWeightsWidth_()) + + std::string(" -DMLO_FILTER_SIZE1=") + std::to_string(problem.GetWeightsHeight_()) + std::string(" -DMLO_FILTER_STRIDE0=") + std::to_string(problem.GetKernelStrideW()) + std::string(" -DMLO_FILTER_STRIDE1=") + std::to_string(problem.GetKernelStrideH()) + std::string(" -DMLO_FILTER_PAD0=") + std::to_string(problem.GetPadW()) + std::string(" -DMLO_FILTER_PAD1=") + std::to_string(problem.GetPadH()) + - std::string(" -DMLO_IN_WIDTH=") + std::to_string(problem.GetInWidth()) + - std::string(" -DMLO_IN_HEIGHT=") + std::to_string(problem.GetInHeight()) + - std::string(" -DMLO_OUT_WIDTH=") + std::to_string(problem.GetOutWidth()) + - std::string(" -DMLO_OUT_HEIGHT=") + std::to_string(problem.GetOutHeight()) + - std::string(" -DMLO_N_OUTPUTS=") + std::to_string(problem.GetOutChannels()) + - std::string(" -DMLO_N_INPUTS=") + std::to_string(problem.GetInChannels()) + - std::string(" -DMLO_BATCH_SZ=") + std::to_string(problem.GetBatchSize()) + + std::string(" -DMLO_IN_WIDTH=") + std::to_string(problem.GetInWidth_()) + + std::string(" -DMLO_IN_HEIGHT=") + std::to_string(problem.GetInHeight_()) + + std::string(" -DMLO_OUT_WIDTH=") + std::to_string(problem.GetOutWidth_()) + + std::string(" -DMLO_OUT_HEIGHT=") + std::to_string(problem.GetOutHeight_()) + + std::string(" -DMLO_N_OUTPUTS=") + std::to_string(problem.GetOutChannels_()) + + std::string(" -DMLO_N_INPUTS=") + std::to_string(problem.GetInChannels_()) + + std::string(" -DMLO_BATCH_SZ=") + std::to_string(problem.GetBatchSize_()) + std::string(" -DMLO_OUT_BATCH_STRIDE=") + - std::to_string(problem.GetOutBatchStride()) + + std::to_string(problem.GetOutBatchStride_()) + std::string(" -DMLO_OUT_CHANNEL_STRIDE=") + - std::to_string(problem.GetOutChannelStride()) + std::string(" -DMLO_OUT_STRIDE=") + - std::to_string(problem.GetOutStride()) + std::string(" -DMLO_IN_BATCH_STRIDE=") + - std::to_string(problem.GetInBatchStride()) + + std::to_string(problem.GetOutChannelStride_()) + std::string(" -DMLO_OUT_STRIDE=") + + std::to_string(problem.GetOutStrideH_()) + std::string(" -DMLO_IN_BATCH_STRIDE=") + + std::to_string(problem.GetInBatchStride_()) + std::string(" -DMLO_IN_CHANNEL_STRIDE=") + - std::to_string(problem.GetInChannelStride()) + std::string(" -DMLO_IN_STRIDE=") + - std::to_string(problem.GetInStride()) + std::string(" -DMLO_WEI_BSTRIDE=") + + std::to_string(problem.GetInChannelStride_()) + std::string(" -DMLO_IN_STRIDE=") + + std::to_string(problem.GetInStrideH_()) + std::string(" -DMLO_WEI_BSTRIDE=") + std::to_string(wei_bstride) + std::string(" -DMLO_WEI_CHANNEL_STRIDE=") + std::to_string(wei_cstride) + // algorithm parameters @@ -387,10 +392,10 @@ ConvSolution ConvOclDirectFwd1x1::GetSolution(const ConvolutionContext& ctx, kernel.l_wk.push_back(1); kernel.l_wk.push_back(1); - size_t gbl_wk0 = static_cast(problem.GetBatchSize()) * MAP_SZ4; + size_t gbl_wk0 = static_cast(problem.GetBatchSize_()) * MAP_SZ4; size_t gbl_wk1 = - (problem.GetOutChannels() + result.n_out_pix_tiles - 1) / result.n_out_pix_tiles; + (problem.GetOutChannels_() + result.n_out_pix_tiles - 1) / result.n_out_pix_tiles; size_t gbl_wk2 = 1; kernel.g_wk.push_back(gbl_wk0); diff --git a/src/solver/conv_ocl_dir2Dfwd_exhaustive_search.cpp b/src/solver/conv_ocl_dir2Dfwd_exhaustive_search.cpp index 4d8788ea0f..7d5d320b81 100644 --- a/src/solver/conv_ocl_dir2Dfwd_exhaustive_search.cpp +++ b/src/solver/conv_ocl_dir2Dfwd_exhaustive_search.cpp @@ -58,12 +58,12 @@ LegacyPerformanceConfig ConvOclDirectFwdLegacyExhaustiveSearch::GetDefaultPerfor { // LegacyPerformanceConfig result{}; - result.in_tile0 = (problem.GetInWidth() <= 8) ? 8 - : (problem.GetInWidth() <= 16) ? 16 - : 32; // size of input data per ALU plane - result.in_tile1 = (problem.GetInHeight() <= 8) ? 8 - : (problem.GetInHeight() <= 16) ? 16 + result.in_tile0 = (problem.GetInWidth_() <= 8) ? 8 + : (problem.GetInWidth_() <= 16) ? 16 : 32; // size of input data per ALU plane + result.in_tile1 = (problem.GetInHeight_() <= 8) ? 8 + : (problem.GetInHeight_() <= 16) ? 16 + : 32; // size of input data per ALU plane result.out_pix_tile0 = std::max(problem.GetKernelStrideW(), @@ -81,14 +81,14 @@ LegacyPerformanceConfig ConvOclDirectFwdLegacyExhaustiveSearch::GetDefaultPerfor result.n_stacks = 1; // # of diff stacks (part of batch). - if(problem.GetWeightsWidth() == 1 && problem.GetWeightsHeight() == 1 && + if(problem.GetWeightsWidth_() == 1 && problem.GetWeightsHeight_() == 1 && problem.GetGroupCount() == 1) // Group conv: None 1x1 version yet, fallback to universal kernel. { // version if(problem.GetInDataType() == miopenFloat && problem.direction.IsForward() && - problem.GetInChannels() % 16 == 0 && problem.GetOutChannels() % 16 == 0) + problem.GetInChannels_() % 16 == 0 && problem.GetOutChannels_() % 16 == 0) { result.n_in_data_tiles = 128; @@ -102,19 +102,19 @@ LegacyPerformanceConfig ConvOclDirectFwdLegacyExhaustiveSearch::GetDefaultPerfor } else { - int i_sz = problem.GetOutHeight() * problem.GetOutWidth(); + int i_sz = problem.GetOutHeight_() * problem.GetOutWidth_(); result.out_pix_tile0 = (i_sz & 1) != 0 ? 1 : 2; if(problem.GetPadW() > 0 || problem.GetKernelStrideW() > 1) { if(problem.direction.IsForward()) { - result.out_pix_tile0 = (problem.GetOutWidth() & 1) != 0 ? 1 : 2; + result.out_pix_tile0 = (problem.GetOutWidth_() & 1) != 0 ? 1 : 2; } else { result.out_pix_tile0 = - (((problem.GetOutWidth() & 1) != 0) || ((problem.GetInWidth() & 1) != 0)) + (((problem.GetOutWidth_() & 1) != 0) || ((problem.GetInWidth_() & 1) != 0)) ? 1 : 2; } @@ -324,18 +324,18 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ConvolutionContext& ctx int out_pix_tl_cnt = 3; // out_pix_tile_sz[1]; int n_out_tls = 4; int n_in_tls = 3; - int stack_cnt = std::min(problem.GetBatchSize(), 2); + int stack_cnt = std::min(problem.GetBatchSize_(), 2U); int n_tile0_sz = 4; int n_tile1_sz = 4; - if(problem.GetOutWidth() >= 16) + if(problem.GetOutWidth_() >= 16) { tile_sz0[0] = 16; tile_sz0[1] = 32; n_tile0_sz = 2; } - if(problem.GetOutHeight() >= 16) + if(problem.GetOutHeight_() >= 16) { tile_sz1[0] = 16; tile_sz1[1] = 32; @@ -348,7 +348,7 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ConvolutionContext& ctx long long runs_left = 0, total_runs = 0; - if(problem.GetWeightsWidth() == 1 && problem.GetWeightsHeight() == 1 && + if(problem.GetWeightsWidth_() == 1 && problem.GetWeightsHeight_() == 1 && problem.GetGroupCount() == 1) // Group conv: None 1x1 version yet, fallback to universal kernel. { @@ -361,7 +361,7 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ConvolutionContext& ctx // Add 1x1_stride : no padding support yet if(problem.GetInDataType() == miopenFloat && problem.direction.IsForward() && - problem.GetInChannels() % 16 == 0 && problem.GetOutChannels() % 16 == 0) + problem.GetInChannels_() % 16 == 0 && problem.GetOutChannels_() % 16 == 0) { // unsigned N_LCL_IN_MAPS = result.n_in_data_tiles; @@ -389,7 +389,7 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ConvolutionContext& ctx } else { - int i_sz = problem.GetInWidth() * problem.GetInHeight(); + int i_sz = problem.GetInWidth_() * problem.GetInHeight_(); if(problem.GetKernelStrideW() == 1) { out_pix_tl_cnt = (i_sz & 1) != 0 ? 1 : (i_sz & 0x3) != 0 ? 2 : 3; @@ -398,12 +398,12 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ConvolutionContext& ctx { if(problem.direction.IsForward()) { - out_pix_tl_cnt = (problem.GetOutWidth() & 1) != 0 ? 1 : 2; + out_pix_tl_cnt = (problem.GetOutWidth_() & 1) != 0 ? 1 : 2; } else { out_pix_tl_cnt = - (((problem.GetOutWidth() & 1) != 0) || ((problem.GetInWidth() & 1) != 0)) + (((problem.GetOutWidth_() & 1) != 0) || ((problem.GetInWidth_() & 1) != 0)) ? 1 : 2; } @@ -413,12 +413,12 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ConvolutionContext& ctx out_pix_tile_sz[2] = 4; n_out_tiles_rg[0] = 2; - n_out_tiles_rg[1] = (problem.GetOutChannels() % 64 == 0) ? 6 - : (problem.GetOutChannels() % 32 == 0) ? 5 - : 4; + n_out_tiles_rg[1] = (problem.GetOutChannels_() % 64 == 0) ? 6 + : (problem.GetOutChannels_() % 32 == 0) ? 5 + : 4; n_in_tiles_rg[0] = 2; - n_in_tiles_rg[1] = (problem.GetInChannels() % 8 == 0) ? 3 : 2; + n_in_tiles_rg[1] = (problem.GetInChannels_() % 8 == 0) ? 3 : 2; grp_tl_ln[0] = 64; grp_tl_ln[1] = 128; @@ -522,7 +522,7 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ConvolutionContext& ctx { int tile_sz[3] = {8, 16, 32}; result.in_tile1 = tile_sz1[j]; - if(problem.GetOutHeight() * 2 <= result.in_tile1 && result.in_tile1 > tile_sz[0]) + if(problem.GetOutHeight_() * 2 <= result.in_tile1 && result.in_tile1 > tile_sz[0]) { --runs_left; continue; @@ -532,19 +532,19 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ConvolutionContext& ctx for(int i = 0; i < n_tile0_sz; ++i) { result.in_tile0 = tile_sz0[i]; - if((problem.GetOutWidth() * 2 <= result.in_tile0 && result.in_tile0 > tile_sz[0])) + if((problem.GetOutWidth_() * 2 <= result.in_tile0 && result.in_tile0 > tile_sz[0])) { --runs_left; continue; } - if(problem.GetOutHeight() > 16 && problem.GetOutWidth() > 16 && + if(problem.GetOutHeight_() > 16 && problem.GetOutWidth_() > 16 && ((result.in_tile1 == 8 && result.in_tile0 == 8) || (result.grp_tile0 == 8 && result.grp_tile1 == 8))) { --runs_left; continue; } - if(problem.GetOutWidth() > 32 && result.in_tile1 > result.in_tile0) + if(problem.GetOutWidth_() > 32 && result.in_tile1 > result.in_tile0) { --runs_left; continue; @@ -576,7 +576,7 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ConvolutionContext& ctx for(int o_t = 0; o_t < n_out_tls; ++o_t) { result.n_out_pix_tiles = n_out_tiles_rg[o_t]; - if(problem.GetOutChannels() < result.n_out_pix_tiles) + if(problem.GetOutChannels_() < result.n_out_pix_tiles) { --runs_left; continue; @@ -585,7 +585,7 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ConvolutionContext& ctx for(int i_t = 0; i_t < n_in_tls; ++i_t) { result.n_in_data_tiles = n_in_tiles_rg[i_t]; - if(problem.GetInChannels() < result.n_in_data_tiles) + if(problem.GetInChannels_() < result.n_in_data_tiles) { --runs_left; continue; @@ -595,7 +595,7 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ConvolutionContext& ctx { result.n_stacks = n_in_stacks_sz[s]; - if(result.n_stacks > problem.GetBatchSize()) + if(result.n_stacks > problem.GetBatchSize_()) { --runs_left; continue; @@ -665,7 +665,7 @@ ConvOclDirectFwdLegacyExhaustiveSearch::SearchImpl(const ConvolutionContext& ctx int ret = -1; double default_time = std::numeric_limits::max(); const auto default_config = GetDefaultPerformanceConfig(ctx, problem); - if(problem.GetWeightsWidth() == 1 && problem.GetWeightsHeight() == 1 && + if(problem.GetWeightsWidth_() == 1 && problem.GetWeightsHeight_() == 1 && problem.GetGroupCount() == 1) // Group conv: None 1x1 version yet, fallback to universal kernel. { diff --git a/src/solver/conv_ocl_dir2Dfwd_fused.cpp b/src/solver/conv_ocl_dir2Dfwd_fused.cpp index 8d599d0bca..72d393a8dd 100644 --- a/src/solver/conv_ocl_dir2Dfwd_fused.cpp +++ b/src/solver/conv_ocl_dir2Dfwd_fused.cpp @@ -56,7 +56,7 @@ ConvOclDirectFwdFused::Search(const FusionContext& context, .weights; const auto& tensors = miopen::ConvFwdTensors{fusion_invoke_params.inDesc, fusion_invoke_params.in, - conv_problem.conv_problem.GetWeights(), + conv_problem.GetWeights(), wei_ocl_ptr, fusion_invoke_params.outDesc, fusion_invoke_params.out}; diff --git a/src/solver/conv_ocl_dir2Dfwdgen.cpp b/src/solver/conv_ocl_dir2Dfwdgen.cpp index f014c832aa..06399252ea 100644 --- a/src/solver/conv_ocl_dir2Dfwdgen.cpp +++ b/src/solver/conv_ocl_dir2Dfwdgen.cpp @@ -61,14 +61,14 @@ bool ConvOclDirectFwdGen::IsApplicable(const ConvolutionContext& ctx, { // Factored out from ConvolutionDescriptor::IsDirectSupported(), which is now dissmissed. const auto& p = problem; // alias const bool supported = - ((p.GetWeightsHeight() == p.GetWeightsWidth()) - && ((p.GetWeightsHeight() == 3 && p.GetKernelStrideH() <= 2 && p.GetKernelStrideW() <= 2) - || p.GetWeightsHeight() == 5 - || p.GetWeightsHeight() == 7 - || p.GetWeightsHeight() == 9 - || p.GetWeightsHeight() == 11)) - || (p.GetWeightsHeight() == 5 - && (p.GetWeightsWidth() == 10 || p.GetWeightsWidth() == 20) + ((p.GetWeightsHeight_() == p.GetWeightsWidth_()) + && ((p.GetWeightsHeight_() == 3 && p.GetKernelStrideH() <= 2 && p.GetKernelStrideW() <= 2) + || p.GetWeightsHeight_() == 5 + || p.GetWeightsHeight_() == 7 + || p.GetWeightsHeight_() == 9 + || p.GetWeightsHeight_() == 11)) + || (p.GetWeightsHeight_() == 5 + && (p.GetWeightsWidth_() == 10 || p.GetWeightsWidth_() == 20) && p.GetKernelStrideH() == 2 && p.GetKernelStrideW() == 2 && p.GetPadH() == 0 @@ -79,7 +79,7 @@ bool ConvOclDirectFwdGen::IsApplicable(const ConvolutionContext& ctx, } { // Workaround for issue 1681 - if(problem.IsFp32() && problem.GetInChannels() > 3) + if(problem.IsFp32() && problem.GetInChannels_() > 3) return false; } @@ -88,9 +88,9 @@ bool ConvOclDirectFwdGen::IsApplicable(const ConvolutionContext& ctx, && problem.GetPadW() == problem.GetPadH() && problem.GetDilationW() == 1 && problem.GetDilationH() == 1 - && (problem.GetWeightsWidth() > 11 - || problem.GetWeightsHeight() > 11 - || (!(problem.GetWeightsWidth() == 1 && problem.GetWeightsHeight() == 1) + && (problem.GetWeightsWidth_() > 11 + || problem.GetWeightsHeight_() > 11 + || (!(problem.GetWeightsWidth_() == 1 && problem.GetWeightsHeight_() == 1) && (problem.GetKernelStrideW() > 1 || problem.GetKernelStrideH() > 1))); // clang-format on } @@ -98,17 +98,17 @@ ConvSolution ConvOclDirectFwdGen::GetSolution(const ConvolutionContext& ctx, const ProblemDescription& problem) const { int n_in_stacks = 0; - if(problem.GetWeightsHeight() == 3 && problem.GetWeightsWidth() == 3) + if(problem.GetWeightsHeight_() == 3 && problem.GetWeightsWidth_() == 3) { // n of input batches - n_in_stacks = ((problem.GetBatchSize() / 4) * 4 == problem.GetBatchSize()) ? 4 - : ((problem.GetBatchSize() / 2) * 2 == problem.GetBatchSize()) ? 2 - : 1; + n_in_stacks = ((problem.GetBatchSize_() / 4) * 4 == problem.GetBatchSize_()) ? 4 + : ((problem.GetBatchSize_() / 2) * 2 == problem.GetBatchSize_()) ? 2 + : 1; } else { // n of input batches - n_in_stacks = ((problem.GetBatchSize() / 2) * 2 == problem.GetBatchSize()) ? 2 : 1; + n_in_stacks = ((problem.GetBatchSize_() / 2) * 2 == problem.GetBatchSize_()) ? 2 : 1; } int n_proc_supertiles = n_in_stacks; // n of prosessing groups auto lg2n_proc_supertiles = @@ -116,11 +116,11 @@ ConvSolution ConvOclDirectFwdGen::GetSolution(const ConvolutionContext& ctx, int n_out_stacks = 1; // n of output sets int n_proc_supertile0 = ((n_in_stacks > 1) ? 32 : 16) / problem.GetKernelStrideW(); // n processor in process supertile - int n_proc_supertile1 = - ((n_in_stacks > 1 && (problem.GetWeightsHeight() >= 11 || problem.GetWeightsWidth() >= 11)) - ? 32 - : 16) / - n_in_stacks; + int n_proc_supertile1 = ((n_in_stacks > 1 && (problem.GetWeightsHeight_() >= 11 || + problem.GetWeightsWidth_() >= 11)) + ? 32 + : 16) / + n_in_stacks; auto lg2n_proc_supertile1 = static_cast(std::ceil(std::log(n_proc_supertile1) / std::log(2))); int ocl_group_sz0 = n_proc_supertile0; @@ -133,20 +133,20 @@ ConvSolution ConvOclDirectFwdGen::GetSolution(const ConvolutionContext& ctx, int n_ins0 = 1; // number of inputs each a from different stack along dim 0 int n_ins1 = 1; // number of inputs each a from different stack along dim 1 - int n_outs = (problem.GetInWidth() >= 384 || - (problem.GetWeightsWidth() >= 11 && problem.GetKernelStrideW() >= 4)) + int n_outs = (problem.GetInWidth_() >= 384 || + (problem.GetWeightsWidth_() >= 11 && problem.GetKernelStrideW() >= 4)) ? 16 : 32; // n outputs per a single input: major parameter - int n_out_pix_horiz = (problem.GetInWidth() < 320 || - (problem.GetWeightsWidth() >= 11 && problem.GetKernelStrideW() >= 4)) + int n_out_pix_horiz = (problem.GetInWidth_() < 320 || + (problem.GetWeightsWidth_() >= 11 && problem.GetKernelStrideW() >= 4)) ? 1 : 2; // n of output px horix per wk-item: major parameter int n_out_pix_vert = 1; // n of output px horix per wk-item: major parameter int n_in_pix_horiz = n_out_pix_horiz; // n of input pix per wk_item int n_in_pix_vert = n_out_pix_vert; // n of input pix per wk_item - int n_v_proc0 = (problem.GetOutWidth() + n_out_pix_horiz - 1) / n_out_pix_horiz; - int n_v_proc1 = (problem.GetOutHeight() + n_out_pix_vert - 1) / n_out_pix_vert; + int n_v_proc0 = (problem.GetOutWidth_() + n_out_pix_horiz - 1) / n_out_pix_horiz; + int n_v_proc1 = (problem.GetOutHeight_() + n_out_pix_vert - 1) / n_out_pix_vert; int big = 1; @@ -160,27 +160,27 @@ ConvSolution ConvOclDirectFwdGen::GetSolution(const ConvolutionContext& ctx, int n_ins = n_ins0 * n_ins1; // number of inputs each a from different stack - n_outs = std::min(n_outs, problem.GetOutChannels()); - n_ins = std::min(n_ins, problem.GetBatchSize()); + n_outs = std::min(n_outs, static_cast(problem.GetOutChannels_())); + n_ins = std::min(n_ins, static_cast(problem.GetBatchSize_())); - n_out_stacks = (n_outs * n_out_stacks <= problem.GetOutChannels()) ? n_out_stacks : 1; - n_in_stacks = (n_ins * n_in_stacks <= problem.GetBatchSize()) ? n_in_stacks : 1; + n_out_stacks = (n_outs * n_out_stacks <= problem.GetOutChannels_()) ? n_out_stacks : 1; + n_in_stacks = (n_ins * n_in_stacks <= problem.GetBatchSize_()) ? n_in_stacks : 1; int total_ins = n_ins * n_in_stacks; int total_outs = n_outs * n_out_stacks; - int n_out_blocks = ((problem.GetOutChannels() + total_outs - 1) / total_outs); - int n_stack_blocks = ((problem.GetBatchSize() + total_ins - 1) / total_ins); + int n_out_blocks = ((problem.GetOutChannels_() + total_outs - 1) / total_outs); + int n_stack_blocks = ((problem.GetBatchSize_() + total_ins - 1) / total_ins); int batch_aligned = 0; #if 1 - if((problem.GetBatchSize() / n_stack_blocks) * n_stack_blocks == problem.GetBatchSize()) + if((problem.GetBatchSize_() / n_stack_blocks) * n_stack_blocks == problem.GetBatchSize_()) { batch_aligned = 1; } #endif int out_aligned = 0; #if 1 - if((problem.GetOutChannels() / total_outs) * total_outs == problem.GetOutChannels()) + if((problem.GetOutChannels_() / total_outs) * total_outs == problem.GetOutChannels_()) { out_aligned = 1; } @@ -212,47 +212,47 @@ ConvSolution ConvOclDirectFwdGen::GetSolution(const ConvolutionContext& ctx, std::string(" -DMLO_OUT_STACKS=") + std::to_string(static_cast(n_out_stacks)) + std::string(" -DMLO_IN_STACKS=") + std::to_string(static_cast(n_in_stacks)) + std::string(" -DMLO_BATCH_SZ=") + - std::to_string(static_cast(problem.GetBatchSize())) + + std::to_string(static_cast(problem.GetBatchSize_())) + std::string(" -DMLO_FLTR_SZ0=") + - std::to_string(static_cast(problem.GetWeightsWidth())) + + std::to_string(static_cast(problem.GetWeightsWidth_())) + std::string(" -DMLO_FLTR_PAD_SZ0=") + std::to_string(static_cast(problem.GetPadW())) + std::string(" -DMLO_FLTR_STRIDE0=") + std::to_string(static_cast(problem.GetKernelStrideW())) + std::string(" -DMLO_FLTR_SZ1=") + - std::to_string(static_cast(problem.GetWeightsHeight())) + + std::to_string(static_cast(problem.GetWeightsHeight_())) + std::string(" -DMLO_FLTR_PAD_SZ1=") + std::to_string(static_cast(problem.GetPadH())) + std::string(" -DMLO_FLTR_STRIDE1=") + std::to_string(static_cast(problem.GetKernelStrideH())) + std::string(" -DMLO_N_OUT_CHNLS=") + std::to_string( - static_cast(problem.GetOutChannels())) // total number of output channels + static_cast(problem.GetOutChannels_())) // total number of output channels + std::string(" -DMLO_OUT_WIDTH=") + - std::to_string(static_cast(problem.GetOutWidth())) + + std::to_string(static_cast(problem.GetOutWidth_())) + std::string(" -DMLO_OUT_HEIGHT=") + - std::to_string(static_cast(problem.GetOutHeight())) + + std::to_string(static_cast(problem.GetOutHeight_())) + std::string(" -DMLO_OUT_STRIDE=") + - std::to_string(static_cast(problem.GetOutStride())) + + std::to_string(static_cast(problem.GetOutStrideH_())) + std::string(" -DMLO_OUT_CHNL_STRIDE=") + - std::to_string(static_cast(problem.GetOutChannelStride())) + + std::to_string(static_cast(problem.GetOutChannelStride_())) + std::string(" -DMLO_OUT_BATCH_STRIDE=") + - std::to_string(static_cast(problem.GetOutBatchStride())) + + std::to_string(static_cast(problem.GetOutBatchStride_())) + std::string(" -DMLO_N_OUT_PIX_SZ0=") + std::to_string(static_cast(n_out_pix_horiz)) + std::string(" -DMLO_N_OUT_PIX_SZ1=") + std::to_string(static_cast(n_out_pix_vert)) + std::string(" -DMLO_N_IN_CHNLS=") + - std::to_string(static_cast(problem.GetInChannels())) + + std::to_string(static_cast(problem.GetInChannels_())) + std::string(" -DMLO_IN_WIDTH=") + - std::to_string(static_cast(problem.GetInWidth())) + + std::to_string(static_cast(problem.GetInWidth_())) + std::string(" -DMLO_IN_HEIGHT=") + - std::to_string(static_cast(problem.GetInHeight())) + + std::to_string(static_cast(problem.GetInHeight_())) + std::string(" -DMLO_IN_STRIDE=") + - std::to_string(static_cast(problem.GetInStride())) + + std::to_string(static_cast(problem.GetInStrideH_())) + std::string(" -DMLO_IN_CHNL_STRIDE=") + - std::to_string(static_cast(problem.GetInChannelStride())) + + std::to_string(static_cast(problem.GetInChannelStride_())) + std::string(" -DMLO_IN_BATCH_STRIDE=") + - std::to_string(static_cast(problem.GetInBatchStride())) + + std::to_string(static_cast(problem.GetInBatchStride_())) + std::string(" -DMLO_N_IN_PIX_SZ0=") + std::to_string( static_cast(n_in_pix_horiz)) // size of output processing group in 0 dim @@ -260,11 +260,13 @@ ConvSolution ConvOclDirectFwdGen::GetSolution(const ConvolutionContext& ctx, std::to_string( static_cast(n_in_pix_vert)) // size of output processing group in 1 dim + std::string(" -DMLO_WEI_SZ=") + - std::to_string(static_cast(problem.GetOutChannels()) * problem.GetInChannels() * - problem.GetWeightsWidth() * problem.GetWeightsHeight()) + + std::to_string(static_cast(problem.GetOutChannels_()) * + problem.GetInChannels_() * problem.GetWeightsWidth_() * + problem.GetWeightsHeight_()) + std::string(" -DMLO_WEIGHTS_STRIDE=") + - std::to_string(static_cast(problem.GetInChannels()) * problem.GetWeightsWidth() * - problem.GetWeightsHeight()) // weights stride + std::to_string(static_cast(problem.GetInChannels_()) * + problem.GetWeightsWidth_() * + problem.GetWeightsHeight_()) // weights stride + std::string(" -DMLO_N_STACKS=") + std::to_string(static_cast(n_stack_blocks)) // n of separate data stacks + std::string(" -DMLO_N_PROCS0=") + diff --git a/src/solver/conv_winoRxS.cpp b/src/solver/conv_winoRxS.cpp index 12b8b81c40..5bce82163c 100644 --- a/src/solver/conv_winoRxS.cpp +++ b/src/solver/conv_winoRxS.cpp @@ -305,8 +305,8 @@ template void PerformanceConfigConvBinWinogradRxS::HeuristicInit(const ConvolutionContext& ctx, const ProblemDescription& problem) { - const auto n_inputs_per_group = problem.GetInChannels() / problem.GetGroupCount(), - n_outputs_per_group = problem.GetOutChannels() / problem.GetGroupCount(); + const auto n_inputs_per_group = problem.GetInChannels_() / problem.GetGroupCount(), + n_outputs_per_group = problem.GetOutChannels_() / problem.GetGroupCount(); if(problem.GetGroupCount() == 1) { n_groups = ctx.GetStream().GetMaxHardwareComputeUnits(); @@ -315,14 +315,14 @@ void PerformanceConfigConvBinWinogradRxS::HeuristicInit(const ConvolutionContext if(problem.direction.IsBackwardWrW()) { - n_groups = GetBestNGroupParam(problem.GetInHeight(), - problem.GetInWidth(), + n_groups = GetBestNGroupParam(problem.GetInHeight_(), + problem.GetInWidth_(), problem.GetDilationH(), problem.GetDilationW(), - problem.GetBatchSize(), // N - n_inputs_per_group, // K - problem.GetWeightsHeight(), - problem.GetWeightsWidth(), + problem.GetBatchSize_(), // N + n_inputs_per_group, // K + problem.GetWeightsHeight_(), + problem.GetWeightsWidth_(), problem.GetPadW(), problem.GetPadH(), n_outputs_per_group, // C @@ -335,17 +335,17 @@ void PerformanceConfigConvBinWinogradRxS::HeuristicInit(const ConvolutionContext } else { - n_groups = GetBestNGroupParam(problem.GetWeightsHeight(), // RxS - problem.GetWeightsWidth(), + n_groups = GetBestNGroupParam(problem.GetWeightsHeight_(), // RxS + problem.GetWeightsWidth_(), problem.GetKernelStrideH(), problem.GetKernelStrideW(), - n_inputs_per_group, // C - n_outputs_per_group, // K - problem.GetOutHeight(), // OHxOW - problem.GetOutWidth(), + n_inputs_per_group, // C + n_outputs_per_group, // K + problem.GetOutHeight_(), // OHxOW + problem.GetOutWidth_(), problem.GetPadW(), problem.GetPadH(), - problem.GetBatchSize(), // N + problem.GetBatchSize_(), // N problem.GetDilationH(), problem.GetDilationW(), ctx.GetStream().GetMaxHardwareComputeUnits(), @@ -642,7 +642,7 @@ static bool IsApplicableBase(const ConvolutionContext& ctx, const ProblemDescrip StartsWith(name, "gfx103") || StartsWith(name, "gfx11"))) return false; - if(name == "gfx90a" && problem.conv_problem.IsGfx90aFp16altRequired()) + if(name == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; // clang-format off @@ -655,37 +655,37 @@ static bool IsApplicableBase(const ConvolutionContext& ctx, const ProblemDescrip return false; // clang-format on - const auto n_inputs_per_group = problem.GetInChannels() / problem.GetGroupCount(), - n_outputs_per_group = problem.GetOutChannels() / problem.GetGroupCount(); + const auto n_inputs_per_group = problem.GetInChannels_() / problem.GetGroupCount(), + n_outputs_per_group = problem.GetOutChannels_() / problem.GetGroupCount(); if(problem.direction.IsBackwardWrW()) { if(problem.GetKernelStrideW() == 2) return false; return IsShaderConstraintsMet(problem, - problem.GetInHeight(), - problem.GetInWidth(), - problem.GetBatchSize(), // N - n_inputs_per_group, // K - problem.GetOutHeight(), - problem.GetOutWidth(), - problem.GetWeightsHeight(), - problem.GetWeightsWidth(), + problem.GetInHeight_(), + problem.GetInWidth_(), + problem.GetBatchSize_(), // N + n_inputs_per_group, // K + problem.GetOutHeight_(), + problem.GetOutWidth_(), + problem.GetWeightsHeight_(), + problem.GetWeightsWidth_(), n_outputs_per_group, // C name); } else { return IsShaderConstraintsMet(problem, - problem.GetWeightsHeight(), // RxS - problem.GetWeightsWidth(), - n_inputs_per_group, // C - n_outputs_per_group, // K - problem.GetInHeight(), // HxW - problem.GetInWidth(), - problem.GetOutHeight(), // OHxOW - problem.GetOutWidth(), - problem.GetBatchSize(), // N + problem.GetWeightsHeight_(), // RxS + problem.GetWeightsWidth_(), + n_inputs_per_group, // C + n_outputs_per_group, // K + problem.GetInHeight_(), // HxW + problem.GetInWidth_(), + problem.GetOutHeight_(), // OHxOW + problem.GetOutWidth_(), + problem.GetBatchSize_(), // N name); } } @@ -927,8 +927,8 @@ ConvSolution ConvBinWinoRxS::GetSolution( &pad_W); N /= group_cnt; K /= group_cnt; - pad_H = problem.conv_problem.GetConv().GetConvPads()[0]; - pad_W = problem.conv_problem.GetConv().GetConvPads()[1]; + pad_H = problem.GetConv().GetConvPads()[0]; + pad_W = problem.GetConv().GetConvPads()[1]; d_layout = GetGroupConvLayout(GetSwappedNCLayout(GetMemLayout_t(problem.GetInLayout())), true); diff --git a/src/solver/conv_winoRxS_fused.cpp b/src/solver/conv_winoRxS_fused.cpp index 6007023087..e242d5d0d8 100644 --- a/src/solver/conv_winoRxS_fused.cpp +++ b/src/solver/conv_winoRxS_fused.cpp @@ -177,19 +177,19 @@ bool ConvBinWinogradRxSf2x3g1Fused::IsApplicable(const FusionContext& context, return false; // clang-format on - const auto group_count = conv_problem.conv_problem.GetGroupCount(); + const auto group_count = conv_problem.GetGroupCount(); if(group_count != 1) return false; - const auto W = conv_problem.conv_problem.GetInWidth(); - const auto H = conv_problem.conv_problem.GetInHeight(); - const auto C = conv_problem.conv_problem.GetInChannels(); - const auto N = conv_problem.conv_problem.GetInBatchSize(); - const auto K = conv_problem.conv_problem.GetOutChannels(); - const auto R = conv_problem.conv_problem.GetWeightsHeight(); - const auto S = conv_problem.conv_problem.GetWeightsWidth(); - const auto OH = conv_problem.conv_problem.GetOutHeight(); - const auto OW = conv_problem.conv_problem.GetOutWidth(); + const auto W = conv_problem.GetInWidth_(); + const auto H = conv_problem.GetInHeight_(); + const auto C = conv_problem.GetInChannels_(); + const auto N = conv_problem.GetInBatchSize_(); + const auto K = conv_problem.GetOutChannels_(); + const auto R = conv_problem.GetWeightsHeight_(); + const auto S = conv_problem.GetWeightsWidth_(); + const auto OH = conv_problem.GetOutHeight_(); + const auto OW = conv_problem.GetOutWidth_(); return IsWinogradV21Preferred<2, 3>(name, conv_problem) ? IsShaderConstraintsMetV21(conv_problem, R, S, C, K, H, W, OH, OW, N) @@ -251,8 +251,8 @@ ConvSolution ConvBinWinogradRxSf2x3g1Fused::GetSolution(const FusionContext& con kernel.kernel_file += kernel_postfix + ".s"; result.construction_params.push_back(kernel); - const auto x = conv_problem.conv_problem.GetWeightsWidth(); - const auto y = conv_problem.conv_problem.GetWeightsHeight(); + const auto x = conv_problem.GetWeightsWidth_(); + const auto y = conv_problem.GetWeightsHeight_(); if(x == 3 && y == 3) result.weight = 100; diff --git a/src/solver/conv_wino_fury_RxS.cpp b/src/solver/conv_wino_fury_RxS.cpp index 743064c65b..035215a48d 100644 --- a/src/solver/conv_wino_fury_RxS.cpp +++ b/src/solver/conv_wino_fury_RxS.cpp @@ -94,10 +94,10 @@ class ShaderModel : public UnifiedDescriptionConv2d Hs{Ceil(out_h, Toh)}, We{Tow * (Ceil(out_w, Tow) + Ceil(Tw, Tow) - 1)}, - W{static_cast(problem.direction.IsBackwardWrW() ? problem.GetOutWidth() - : problem.GetInWidth())}, - H{static_cast(problem.direction.IsBackwardWrW() ? problem.GetOutHeight() - : problem.GetInHeight())}, + W{static_cast(problem.direction.IsBackwardWrW() ? problem.GetOutWidth_() + : problem.GetInWidth_())}, + H{static_cast(problem.direction.IsBackwardWrW() ? problem.GetOutHeight_() + : problem.GetInHeight_())}, d_H_clip{static_cast(static_cast(Hs * Toh) - pad_h)}, d_W_clip{static_cast(We - pad_w)}, diff --git a/src/solver/fft.cpp b/src/solver/fft.cpp index 4620be887e..9a3c7858cc 100644 --- a/src/solver/fft.cpp +++ b/src/solver/fft.cpp @@ -112,7 +112,7 @@ bool fft::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& pr std::ignore = ctx; // disable running any FFT based convolutions by checking this env variable - if(problem.direction.IsBackwardWrW() || !problem.conv_problem.IsFp32()) + if(problem.direction.IsBackwardWrW() || !problem.IsFp32()) return false; if(!problem.IsLayoutDefault()) @@ -121,10 +121,10 @@ bool fft::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& pr } const auto is_fwd = problem.direction.IsForward(); - decltype(auto) conv = problem.conv_problem.GetConv(); - decltype(auto) xDesc = is_fwd ? problem.conv_problem.GetIn() : problem.conv_problem.GetOut(); - decltype(auto) yDesc = is_fwd ? problem.conv_problem.GetOut() : problem.conv_problem.GetIn(); - decltype(auto) wDesc = problem.conv_problem.GetWeights(); + decltype(auto) conv = problem.GetConv(); + decltype(auto) xDesc = is_fwd ? problem.GetIn() : problem.GetOut(); + decltype(auto) yDesc = is_fwd ? problem.GetOut() : problem.GetIn(); + decltype(auto) wDesc = problem.GetWeights(); if(conv.GetSpatialDimension() != 2 || conv.group_count != 1 || !miopen::all_of(conv.GetConvDilations(), [](auto v) { return v == 1; })) @@ -161,9 +161,9 @@ bool fft::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& pr size_t fft::GetWorkspaceSize(const ExecutionContext&, const ProblemDescription& problem) const { const auto fwd = problem.direction.IsForward(); - decltype(auto) xDesc = fwd ? problem.conv_problem.GetIn() : problem.conv_problem.GetOut(); - decltype(auto) yDesc = fwd ? problem.conv_problem.GetOut() : problem.conv_problem.GetIn(); - decltype(auto) wDesc = problem.conv_problem.GetWeights(); + decltype(auto) xDesc = fwd ? problem.GetIn() : problem.GetOut(); + decltype(auto) yDesc = fwd ? problem.GetOut() : problem.GetIn(); + decltype(auto) wDesc = problem.GetWeights(); int in_n, in_c, in_h, in_w; std::tie(in_n, in_c, in_h, in_w) = miopen::tien<4>(xDesc.GetLengths()); @@ -202,12 +202,12 @@ ConvSolution fft::GetSolution(const ExecutionContext& ctx, const ProblemDescript { std::ignore = ctx; - int in_n = problem.GetBatchSize(); - int in_c = problem.GetInChannels(); - int in_h = problem.GetInHeight(); - int in_w = problem.GetInWidth(); - int out_n = problem.GetBatchSize(); - int out_c = problem.GetOutChannels(); + int in_n = problem.GetBatchSize_(); + int in_c = problem.GetInChannels_(); + int in_h = problem.GetInHeight_(); + int in_w = problem.GetInWidth_(); + int out_n = problem.GetBatchSize_(); + int out_c = problem.GetOutChannels_(); const int N = FFTConvParams::TileSize(in_h, in_w); const int NumKernels = FFTConvParams::NumKernels; diff --git a/src/solver/pooling/forwardNaive.cpp b/src/solver/pooling/forwardNaive.cpp index 64e3f3dd51..d8a13a330f 100644 --- a/src/solver/pooling/forwardNaive.cpp +++ b/src/solver/pooling/forwardNaive.cpp @@ -132,7 +132,7 @@ PoolingForwardNaive::GetSolution(const ExecutionContext& context, /// not require widening to size_t prior mul, but (d_stride * dim * dim) /// requires it because the total number of muls is 4. - const auto spatial_dim = is2d ? 2 : 3; + const auto spatial_dim = is2d ? 2U : 3U; uint32_t all_n, all_c, bot_d, bot_h, bot_w; std::tie(all_n, all_c, bot_d, bot_h, bot_w) = miopen::GetNCDHW(spatial_dim, bot.GetLengths()); uint32_t bot_w_stride, bot_h_stride, bot_d_stride; diff --git a/test/conv_common.hpp b/test/conv_common.hpp index 8e5c4157d2..286319cb83 100644 --- a/test/conv_common.hpp +++ b/test/conv_common.hpp @@ -143,14 +143,14 @@ static inline bool skip_config(miopen::Handle& handle, ctx.general_compile_options = ""; ctx.disable_perfdb_access = true; ctx.SetStream(&handle); - problem.conv_problem.SetupFloats(ctx); + problem.SetupFloats(ctx); return ctx.GetStream().GetDeviceName() == "gfx908" && problem.Is2d() && problem.IsFp16() && problem.IsLayoutDefault() && ctx.use_hip_kernels && problem.GetGroupCount() == 1 && - problem.GetBatchSize() == 1 && problem.GetInChannels() == 192 && - problem.GetInHeight() == 28 && problem.GetInWidth() == 28 && - problem.GetOutChannels() == 1 && problem.GetWeightsHeight() == 3 && - problem.GetWeightsWidth() == 3 && problem.GetPadW() == 1 && problem.GetPadH() == 1 && + problem.GetBatchSize_() == 1 && problem.GetInChannels_() == 192 && + problem.GetInHeight_() == 28 && problem.GetInWidth_() == 28 && + problem.GetOutChannels_() == 1 && problem.GetWeightsHeight_() == 3 && + problem.GetWeightsWidth_() == 3 && problem.GetPadW() == 1 && problem.GetPadH() == 1 && problem.GetKernelStrideW() == 1 && problem.GetKernelStrideH() == 1 && problem.GetDilationW() == 1 && problem.GetDilationH() == 1; } diff --git a/test/solver.cpp b/test/solver.cpp index 30baea3ec0..3777daf6ae 100644 --- a/test/solver.cpp +++ b/test/solver.cpp @@ -51,7 +51,7 @@ class TrivialTestSolver final : public solver::ConvSolver bool IsApplicable(const ConvolutionContext&, const ProblemDescription& problem) const override { - return problem.GetInWidth() == 1; + return problem.GetInWidth_() == 1; } solver::ConvSolution GetSolution(const ConvolutionContext&,