From c233b80e2ebec5dff7db21368dd5926d2b5555fe Mon Sep 17 00:00:00 2001 From: Evgenii Averin <86725875+averinevg@users.noreply.github.com> Date: Tue, 1 Oct 2024 23:06:54 +0200 Subject: [PATCH] [BugFix] Fix ConvDirectNaiveConvFwd not applicable after #3213 (#3284) --- src/conv/problem_description.cpp | 18 ++++++++++++++++++ .../miopen/conv/problem_description.hpp | 3 +++ 2 files changed, 21 insertions(+) diff --git a/src/conv/problem_description.cpp b/src/conv/problem_description.cpp index 46250ca19c..f868e0544d 100644 --- a/src/conv/problem_description.cpp +++ b/src/conv/problem_description.cpp @@ -119,6 +119,24 @@ std::string ProblemDescription::GetAlphaBetaCaseStr() const } } +void ProblemDescription::HeuristicUpdateLayouts() +{ + static const std::vector supported_layouts = {"NCHW", "NHWC", "CHWN", "NCDHW"}; + + for(const std::string& layout : supported_layouts) + { + if(in.IsPossibleLayout4D5D(layout) && out.IsPossibleLayout4D5D(layout) && + weights.IsPossibleLayout4D5D(layout)) + { + in_layout = layout; + weights_layout = layout; + out_layout = layout; + return; + } + } + // If we did not find consistent layout, leave them as-is +} + void ProblemDescription::MakeNetworkConfig(std::string& conf_key) const { std::ostringstream ss; diff --git a/src/include/miopen/conv/problem_description.hpp b/src/include/miopen/conv/problem_description.hpp index 5d1d89ce23..f0ac17d0b0 100644 --- a/src/include/miopen/conv/problem_description.hpp +++ b/src/include/miopen/conv/problem_description.hpp @@ -166,6 +166,7 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase beta(beta_), alpha_beta_case(ClassifyAlphaBeta(alpha, beta)) { + HeuristicUpdateLayouts(); } // Conv descriptor getters @@ -368,6 +369,8 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase out.AllLengthsFitIntoInt(); } + void HeuristicUpdateLayouts(); + void MakeNetworkConfig(std::string& conf_key) const; NetworkConfig MakeNetworkConfig() const override