Skip to content

Commit

Permalink
TunaNet Integration: MI250x (#2421)
Browse files Browse the repository at this point in the history
  • Loading branch information
msaudulhassan authored Oct 24, 2023
1 parent 102fbee commit 411b345
Show file tree
Hide file tree
Showing 4 changed files with 375 additions and 6 deletions.
109 changes: 104 additions & 5 deletions src/conv/heuristics/ai_heuristics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class Model
virtual std::vector<float> ToFeatures(const ProblemDescription& problem) const = 0;
};

class Gfx908Model : public Model
class Gfx908Model final : public Model
{
public:
Gfx908Model() : Model("gfx908") {}
Expand Down Expand Up @@ -255,7 +255,106 @@ class Gfx908Model : public Model
}
};

std::unique_ptr<Model> GetModel(const std::string&) { return std::make_unique<Gfx908Model>(); }
class Gfx90aModel final : public Model
{
public:
Gfx90aModel() : Model("gfx90a") {}
bool IsProblemSupported(const ProblemDescription& problem,
const ExecutionContext& ctx) const override
{
// check if problem is of the kind TunaNet was trained to handle
if(!problem.Is2d())
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Problem not 2D");
return false;
}
if(problem.GetInLayout() != "NCHW")
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Layout not supported");
return false;
}
if(problem.GetKernelStrideH() != problem.GetKernelStrideW())
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Stride must be equal along all axes");
return false;
}
if(problem.GetDilationH() != problem.GetDilationW())
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Dilation must be 1");
return false;
}
if(problem.GetBias() != 0)
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Bias must be 0");
return false;
}
const auto data_type = problem.GetInDataType();
if(data_type != miopenFloat && data_type != miopenHalf && data_type != miopenBFloat16)
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Unsupported data type");
return false;
}

// check if the context is s.t. no solver TunaNet may predict would be applicable
size_t applicable_solvers = 0;
for(const auto& solver_name : metadata.solver_map)
{
auto solver_id = solver::Id{solver_name.second};
auto solver = solver_id.GetSolver();
if(solver.IsApplicable(ctx, problem))
{
applicable_solvers++;
break;
}
}
if(applicable_solvers == 0)
{
MIOPEN_LOG_I2("TunaNet Inapplicable: No solver that TunaNet may predict applies");
return false;
}
MIOPEN_LOG_I2("TunaNet Applicable");
return true;
}

protected:
std::vector<float> ToFeatures(const ProblemDescription& problem) const override
{
const bool isFwd = problem.GetDirection() == conv::Direction::Forward;
std::vector<float> features = {
static_cast<float>(isFwd ? problem.GetInChannels_() : problem.GetOutChannels_()),
static_cast<float>(isFwd ? problem.GetInHeight_() : problem.GetOutHeight_()),
static_cast<float>(isFwd ? problem.GetInWidth_() : problem.GetOutWidth_()),
static_cast<float>(isFwd ? problem.GetOutChannels_() : problem.GetInChannels_()),
static_cast<float>(isFwd ? problem.GetOutHeight_() : problem.GetInHeight_()),
static_cast<float>(isFwd ? problem.GetOutWidth_() : problem.GetInWidth_()),
static_cast<float>(problem.GetWeightsHeight_()),
static_cast<float>(problem.GetWeightsWidth_()),
static_cast<float>(problem.GetPadH()),
static_cast<float>(problem.GetPadW()),
static_cast<float>(problem.GetKernelStrideH()),
static_cast<float>(problem.GetKernelStrideW()),
static_cast<float>(problem.GetDilationH()),
static_cast<float>(problem.GetDilationW()),
static_cast<float>(problem.GetOutBatchSize_()),
static_cast<float>(metadata.EncodePrecision(problem.GetInDataType())),
static_cast<float>(metadata.EncodeDirection(problem.GetDirection())),
static_cast<float>(problem.GetGroupCount())};

// normalize
for(size_t i = 0; i < features.size(); ++i)
features[i] = (features[i] - metadata.features_mean[i]) / metadata.features_std[i];

return features;
}
};

std::unique_ptr<Model> GetModel(const std::string& device)
{
if(device == "gfx90a")
return std::make_unique<Gfx90aModel>();
else
return std::make_unique<Gfx908Model>();
}

std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
const ExecutionContext& ctx,
Expand All @@ -270,7 +369,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
auto db_res = db.FindRecord(static_cast<const conv::ProblemDescription&>(problem));
if(db_res)
{
MIOPEN_LOG_I2("Cached heuristic result found");
MIOPEN_LOG_I2("Cached heuristic (TunaNet) result found");
std::vector<uint64_t> db_sol(db_res->size());
// cast returned record to solver ids
std::transform(db_res->begin(), db_res->end(), db_sol.begin(), [](boost::any id) {
Expand All @@ -286,7 +385,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
return db_sol;
}

MIOPEN_LOG_I2("Evaluating Heuristic");
MIOPEN_LOG_I2("Evaluating TunaNet");

std::vector<float> res = model->Forward(problem);
std::vector<std::pair<int, float>> sort_res(res.size());
Expand Down Expand Up @@ -322,7 +421,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
std::stringstream ss;
for(auto& id : sol)
ss << solver::Id{id}.ToString() << " ID:" << id << ", ";
MIOPEN_LOG_I2("Heuristic Result: " << ss.str());
MIOPEN_LOG_I2("TunaNet Result: " << ss.str());
}
return sol;
}
Expand Down
2 changes: 1 addition & 1 deletion src/include/miopen/problem_description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ struct ProblemDescriptionCompatTemporary
int GetOutWidth() const { return out_width; }
// int GetOutDepth() const { return out_depth; }
int GetBatchSize() const { return batch_sz; }
// int GetBias() const { return bias; }
int GetBias() const { return bias; }
// std::string GetInLayout() const { return in_layout; }
// std::string GetOutLayout() const { return out_layout; }
miopenDataType_t GetInDataType() const { return in_data_type; }
Expand Down
1 change: 1 addition & 0 deletions src/kernels/gfx90a.tn.model

Large diffs are not rendered by default.

Loading

0 comments on commit 411b345

Please sign in to comment.