Skip to content

Commit

Permalink
[NFC] Replace miopen::ProblemDescription with conv::ProblemDescriptio…
Browse files Browse the repository at this point in the history
…n, part 3 (#2303)
  • Loading branch information
averinevg authored Aug 28, 2023
1 parent b85c7c3 commit 21df5bf
Show file tree
Hide file tree
Showing 89 changed files with 1,904 additions and 2,049 deletions.
70 changes: 33 additions & 37 deletions src/conv/heuristics/ai_heuristics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,43 +153,42 @@ class Gfx908Model : public Model
const ConvolutionContext& ctx) const override
{
// check if problem is of the kind TunaNet was trained to handle
if(!problem.conv_problem.Is2d())
if(!problem.Is2d())
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Problem not 2D");
return false;
}
if(problem.conv_problem.GetGroupCount() != 1)
if(problem.GetGroupCount() != 1)
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Group count not 1");
return false;
}
if(problem.conv_problem.GetInLayout() != "NCHW" &&
problem.conv_problem.GetInLayout() != "NCDHW")
if(problem.GetInLayout() != "NCHW" && problem.GetInLayout() != "NCDHW")
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Layout not supported");
return false;
}
if(problem.conv_problem.GetWeightsHeight() != problem.conv_problem.GetWeightsWidth())
if(problem.GetWeightsHeight_() != problem.GetWeightsWidth_())
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Filters must be square (fil_h == fil_w)");
return false;
}
if(problem.conv_problem.GetPadH() != problem.conv_problem.GetPadW())
if(problem.GetPadH() != problem.GetPadW())
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Padding must be equal along all axes");
return false;
}
if(problem.conv_problem.GetKernelStrideH() != problem.conv_problem.GetKernelStrideW())
if(problem.GetKernelStrideH() != problem.GetKernelStrideW())
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Stride must be equal along all axes");
return false;
}
if(problem.conv_problem.GetDilationH() != 1 || problem.conv_problem.GetDilationW() != 1)
if(problem.GetDilationH() != 1 || problem.GetDilationW() != 1)
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Dilation must be 1");
return false;
}
const auto& data_type = problem.conv_problem.GetInDataType();
const auto data_type = problem.GetInDataType();
if(data_type != miopenFloat && data_type != miopenHalf && data_type != miopenBFloat16)
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Unsupported data type");
Expand Down Expand Up @@ -219,37 +218,34 @@ class Gfx908Model : public Model
protected:
std::vector<float> ToFeatures(const ProblemDescription& problem) const override
{
const auto& conv_problem = problem.conv_problem;
const bool isFwd = conv_problem.GetDirection() == conv::Direction::Forward;
const bool isFwd = problem.GetDirection() == conv::Direction::Forward;
std::vector<float> features = {
static_cast<float>(isFwd ? conv_problem.GetInChannels()
: conv_problem.GetOutChannels()),
static_cast<float>(isFwd ? conv_problem.GetInDepth() : conv_problem.GetOutDepth()),
static_cast<float>(isFwd ? conv_problem.GetInHeight() : conv_problem.GetOutHeight()),
static_cast<float>(isFwd ? conv_problem.GetInWidth() : conv_problem.GetOutWidth()),
static_cast<float>(conv_problem.GetWeightsDepth()),
static_cast<float>(conv_problem.GetWeightsHeight()),
static_cast<float>(conv_problem.GetWeightsWidth()),
static_cast<float>(isFwd ? conv_problem.GetOutChannels()
: conv_problem.GetInChannels()),
static_cast<float>(isFwd ? conv_problem.GetOutDepth() : conv_problem.GetInDepth()),
static_cast<float>(isFwd ? conv_problem.GetOutHeight() : conv_problem.GetInHeight()),
static_cast<float>(isFwd ? conv_problem.GetOutWidth() : conv_problem.GetInWidth()),
static_cast<float>(conv_problem.GetOutBatchSize()),
static_cast<float>(isFwd ? problem.GetInChannels_() : problem.GetOutChannels_()),
static_cast<float>(isFwd ? problem.GetInDepth_() : problem.GetOutDepth_()),
static_cast<float>(isFwd ? problem.GetInHeight_() : problem.GetOutHeight_()),
static_cast<float>(isFwd ? problem.GetInWidth_() : problem.GetOutWidth_()),
static_cast<float>(problem.GetWeightsDepth_()),
static_cast<float>(problem.GetWeightsHeight_()),
static_cast<float>(problem.GetWeightsWidth_()),
static_cast<float>(isFwd ? problem.GetOutChannels_() : problem.GetInChannels_()),
static_cast<float>(isFwd ? problem.GetOutDepth_() : problem.GetInDepth_()),
static_cast<float>(isFwd ? problem.GetOutHeight_() : problem.GetInHeight_()),
static_cast<float>(isFwd ? problem.GetOutWidth_() : problem.GetInWidth_()),
static_cast<float>(problem.GetOutBatchSize_()),
static_cast<float>(1), // TunaNet was trained on a dataset of 2D
// problems where PadD was incorrectly set to 1
static_cast<float>(conv_problem.GetPadH()),
static_cast<float>(conv_problem.GetPadW()),
static_cast<float>(problem.GetPadH()),
static_cast<float>(problem.GetPadW()),
static_cast<float>(1), // TunaNet was trained on a dataset of 2D
// problems where StrideD was incorrectly set to 1
static_cast<float>(conv_problem.GetKernelStrideH()),
static_cast<float>(conv_problem.GetKernelStrideW()),
static_cast<float>(conv_problem.GetDilationH()),
static_cast<float>(conv_problem.GetDilationW()),
static_cast<float>(metadata.EncodeLayout(conv_problem.GetInLayout())),
static_cast<float>(metadata.EncodePrecision(conv_problem.GetInDataType())),
static_cast<float>(metadata.EncodeDirection(conv_problem.GetDirection())),
static_cast<float>(conv_problem.GetGroupCount())};
static_cast<float>(problem.GetKernelStrideH()),
static_cast<float>(problem.GetKernelStrideW()),
static_cast<float>(problem.GetDilationH()),
static_cast<float>(problem.GetDilationW()),
static_cast<float>(metadata.EncodeLayout(problem.GetInLayout())),
static_cast<float>(metadata.EncodePrecision(problem.GetInDataType())),
static_cast<float>(metadata.EncodeDirection(problem.GetDirection())),
static_cast<float>(problem.GetGroupCount())};

// normalize
for(size_t i = 0; i < features.size(); ++i)
Expand All @@ -271,7 +267,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,

std::string est_name = ":memory:" + device;
auto& db = AnyRamDb::GetCached(est_name);
auto db_res = db.FindRecord(problem.conv_problem);
auto db_res = db.FindRecord(static_cast<const conv::ProblemDescription&>(problem));
if(db_res)
{
MIOPEN_LOG_I2("Cached heuristic result found");
Expand Down Expand Up @@ -320,7 +316,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
sol.push_back(sol_id.Value());
any_sol.push_back(sol_id.Value());
}
db.StoreRecord(problem.conv_problem, any_sol);
db.StoreRecord(static_cast<const conv::ProblemDescription&>(problem), any_sol);
if(miopen::IsLogging(LoggingLevel::Info2))
{
std::stringstream ss;
Expand Down
2 changes: 1 addition & 1 deletion src/conv/invokers/impl_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ InvokerFactory MakeImplGemmDataInvokerFactory(const miopen::ProblemDescription&
if(problem.direction.IsBackwardWrW())
MIOPEN_THROW("MakeImplGemmDataInvokerFactory shouldn't be used for WrW invokers.");

const auto& conv = problem.conv_problem.GetConv();
const auto& conv = problem.GetConv();
const auto& lowp_quant = conv.lowp_quant;

return [conv, lowp_quant](const std::vector<Kernel>& kernels) {
Expand Down
Loading

0 comments on commit 21df5bf

Please sign in to comment.