Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NFC] Replace miopen::ProblemDescription with conv::ProblemDescription, part 3 #2303

Merged
Show file tree
Hide file tree
Changes from 41 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
223f8e7
ProblemDescription: use data types from conv::ProblemDescription for …
averinevg Jun 21, 2023
742f46c
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Jun 22, 2023
6cd365f
Revert GetBackwardPad
averinevg Jun 22, 2023
e36f0b8
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Jun 22, 2023
17d41b5
Remove unused code
averinevg Jun 22, 2023
b98dda6
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Jun 23, 2023
95653b2
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Jun 27, 2023
64fbd3e
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Jul 11, 2023
5ee0c09
Revert getters
averinevg Jul 11, 2023
e795657
Rename getters
averinevg Jul 11, 2023
6cdb6b8
Fix formatting
averinevg Jul 11, 2023
6a8276c
Fix formatting
averinevg Jul 11, 2023
747712f
Make ProblemDescription derived from conv::ProblemDescription
averinevg Jul 12, 2023
c9dc957
Fix formatting
averinevg Jul 12, 2023
8e56a57
Remove unused getters
averinevg Jul 12, 2023
4b4bf33
Fix fin
averinevg Jul 13, 2023
936b5ee
Fix formatting
averinevg Jul 13, 2023
4e9603b
conv::ProblemDescription: change return data type of getters
averinevg Jul 13, 2023
cb4b8c0
Fix formatting
averinevg Jul 13, 2023
23edc24
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Jul 14, 2023
7de3724
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Jul 17, 2023
6a74372
Use getters from conv::ProblemDescription
averinevg Jul 18, 2023
0e28cf3
Fix formatting
averinevg Jul 18, 2023
9c2a4c7
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Jul 19, 2023
a1afdac
Use getters from conv::ProblemDescription
averinevg Jul 19, 2023
f434410
Fix formatting
averinevg Jul 19, 2023
d270a7a
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Jul 25, 2023
51ca223
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Jul 26, 2023
d06e307
Use getters from conv::ProblemDescription
averinevg Jul 26, 2023
2888e32
Fix formatting
averinevg Jul 26, 2023
9f58db9
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Jul 27, 2023
f2b9886
Use getters from conv::ProblemDescription
averinevg Jul 27, 2023
6e273f1
Use getters from conv::ProblemDescription
averinevg Jul 28, 2023
cb030b4
Fix formatting
averinevg Jul 28, 2023
ddfa4d6
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Jul 29, 2023
cc0c4cc
Use getters from conv::ProblemDescription
averinevg Jul 29, 2023
bb432f2
Fix formatting
averinevg Jul 29, 2023
d33e626
conv::ProblemDescription: replace data types for getters
averinevg Aug 1, 2023
5ad8da3
Fix formatting
averinevg Aug 1, 2023
5c857ea
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Aug 2, 2023
931203d
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Aug 10, 2023
f1565df
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Aug 15, 2023
ab210e5
Fix 3d group forward convolution solver
averinevg Aug 15, 2023
e84336e
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Aug 21, 2023
e93970f
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Aug 22, 2023
e3ad58e
Merge branch 'develop' into ea_replace_problem_description_with_conv_…
averinevg Aug 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 33 additions & 37 deletions src/conv/heuristics/ai_heuristics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,43 +153,42 @@ class Gfx908Model : public Model
const ConvolutionContext& ctx) const override
{
// check if problem is of the kind TunaNet was trained to handle
if(!problem.conv_problem.Is2d())
if(!problem.Is2d())
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Problem not 2D");
return false;
}
if(problem.conv_problem.GetGroupCount() != 1)
if(problem.GetGroupCount() != 1)
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Group count not 1");
return false;
}
if(problem.conv_problem.GetInLayout() != "NCHW" &&
problem.conv_problem.GetInLayout() != "NCDHW")
if(problem.GetInLayout() != "NCHW" && problem.GetInLayout() != "NCDHW")
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Layout not supported");
return false;
}
if(problem.conv_problem.GetWeightsHeight() != problem.conv_problem.GetWeightsWidth())
if(problem.GetWeightsHeight_() != problem.GetWeightsWidth_())
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Filters must be square (fil_h == fil_w)");
return false;
}
if(problem.conv_problem.GetPadH() != problem.conv_problem.GetPadW())
if(problem.GetPadH() != problem.GetPadW())
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Padding must be equal along all axes");
return false;
}
if(problem.conv_problem.GetKernelStrideH() != problem.conv_problem.GetKernelStrideW())
if(problem.GetKernelStrideH() != problem.GetKernelStrideW())
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Stride must be equal along all axes");
return false;
}
if(problem.conv_problem.GetDilationH() != 1 || problem.conv_problem.GetDilationW() != 1)
if(problem.GetDilationH() != 1 || problem.GetDilationW() != 1)
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Dilation must be 1");
return false;
}
const auto& data_type = problem.conv_problem.GetInDataType();
const auto data_type = problem.GetInDataType();
if(data_type != miopenFloat && data_type != miopenHalf && data_type != miopenBFloat16)
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Unsupported data type");
Expand Down Expand Up @@ -219,37 +218,34 @@ class Gfx908Model : public Model
protected:
std::vector<float> ToFeatures(const ProblemDescription& problem) const override
{
const auto& conv_problem = problem.conv_problem;
const bool isFwd = conv_problem.GetDirection() == conv::Direction::Forward;
const bool isFwd = problem.GetDirection() == conv::Direction::Forward;
std::vector<float> features = {
static_cast<float>(isFwd ? conv_problem.GetInChannels()
: conv_problem.GetOutChannels()),
static_cast<float>(isFwd ? conv_problem.GetInDepth() : conv_problem.GetOutDepth()),
static_cast<float>(isFwd ? conv_problem.GetInHeight() : conv_problem.GetOutHeight()),
static_cast<float>(isFwd ? conv_problem.GetInWidth() : conv_problem.GetOutWidth()),
static_cast<float>(conv_problem.GetWeightsDepth()),
static_cast<float>(conv_problem.GetWeightsHeight()),
static_cast<float>(conv_problem.GetWeightsWidth()),
static_cast<float>(isFwd ? conv_problem.GetOutChannels()
: conv_problem.GetInChannels()),
static_cast<float>(isFwd ? conv_problem.GetOutDepth() : conv_problem.GetInDepth()),
static_cast<float>(isFwd ? conv_problem.GetOutHeight() : conv_problem.GetInHeight()),
static_cast<float>(isFwd ? conv_problem.GetOutWidth() : conv_problem.GetInWidth()),
static_cast<float>(conv_problem.GetOutBatchSize()),
static_cast<float>(isFwd ? problem.GetInChannels_() : problem.GetOutChannels_()),
static_cast<float>(isFwd ? problem.GetInDepth_() : problem.GetOutDepth_()),
static_cast<float>(isFwd ? problem.GetInHeight_() : problem.GetOutHeight_()),
static_cast<float>(isFwd ? problem.GetInWidth_() : problem.GetOutWidth_()),
static_cast<float>(problem.GetWeightsDepth_()),
static_cast<float>(problem.GetWeightsHeight_()),
static_cast<float>(problem.GetWeightsWidth_()),
static_cast<float>(isFwd ? problem.GetOutChannels_() : problem.GetInChannels_()),
static_cast<float>(isFwd ? problem.GetOutDepth_() : problem.GetInDepth_()),
static_cast<float>(isFwd ? problem.GetOutHeight_() : problem.GetInHeight_()),
static_cast<float>(isFwd ? problem.GetOutWidth_() : problem.GetInWidth_()),
static_cast<float>(problem.GetOutBatchSize_()),
static_cast<float>(1), // TunaNet was trained on a dataset of 2D
// problems where PadD was incorrectly set to 1
static_cast<float>(conv_problem.GetPadH()),
static_cast<float>(conv_problem.GetPadW()),
static_cast<float>(problem.GetPadH()),
static_cast<float>(problem.GetPadW()),
static_cast<float>(1), // TunaNet was trained on a dataset of 2D
// problems where StrideD was incorrectly set to 1
static_cast<float>(conv_problem.GetKernelStrideH()),
static_cast<float>(conv_problem.GetKernelStrideW()),
static_cast<float>(conv_problem.GetDilationH()),
static_cast<float>(conv_problem.GetDilationW()),
static_cast<float>(metadata.EncodeLayout(conv_problem.GetInLayout())),
static_cast<float>(metadata.EncodePrecision(conv_problem.GetInDataType())),
static_cast<float>(metadata.EncodeDirection(conv_problem.GetDirection())),
static_cast<float>(conv_problem.GetGroupCount())};
static_cast<float>(problem.GetKernelStrideH()),
static_cast<float>(problem.GetKernelStrideW()),
static_cast<float>(problem.GetDilationH()),
static_cast<float>(problem.GetDilationW()),
static_cast<float>(metadata.EncodeLayout(problem.GetInLayout())),
static_cast<float>(metadata.EncodePrecision(problem.GetInDataType())),
static_cast<float>(metadata.EncodeDirection(problem.GetDirection())),
static_cast<float>(problem.GetGroupCount())};

// normalize
for(size_t i = 0; i < features.size(); ++i)
Expand All @@ -271,7 +267,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,

std::string est_name = ":memory:" + device;
auto& db = AnyRamDb::GetCached(est_name);
auto db_res = db.FindRecord(problem.conv_problem);
auto db_res = db.FindRecord(static_cast<const conv::ProblemDescription&>(problem));
if(db_res)
{
MIOPEN_LOG_I2("Cached heuristic result found");
Expand Down Expand Up @@ -320,7 +316,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
sol.push_back(sol_id.Value());
any_sol.push_back(sol_id.Value());
}
db.StoreRecord(problem.conv_problem, any_sol);
db.StoreRecord(static_cast<const conv::ProblemDescription&>(problem), any_sol);
if(miopen::IsLogging(LoggingLevel::Info2))
{
std::stringstream ss;
Expand Down
2 changes: 1 addition & 1 deletion src/conv/invokers/impl_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ InvokerFactory MakeImplGemmDataInvokerFactory(const miopen::ProblemDescription&
if(problem.direction.IsBackwardWrW())
MIOPEN_THROW("MakeImplGemmDataInvokerFactory shouldn't be used for WrW invokers.");

const auto& conv = problem.conv_problem.GetConv();
const auto& conv = problem.GetConv();
const auto& lowp_quant = conv.lowp_quant;

return [conv, lowp_quant](const std::vector<Kernel>& kernels) {
Expand Down
Loading