Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TunaNet Integration: MI250x #2421

Merged
merged 11 commits into from
Oct 24, 2023
111 changes: 106 additions & 5 deletions src/conv/heuristics/ai_heuristics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class Model
virtual std::vector<float> ToFeatures(const ProblemDescription& problem) const = 0;
};

class Gfx908Model : public Model
class Gfx908Model final : public Model
{
public:
Gfx908Model() : Model("gfx908") {}
Expand Down Expand Up @@ -255,7 +255,108 @@ class Gfx908Model : public Model
}
};

std::unique_ptr<Model> GetModel(const std::string&) { return std::make_unique<Gfx908Model>(); }
class Gfx90aModel final : public Model
{
public:
Gfx90aModel() : Model("gfx90a") {}
bool IsProblemSupported(const ProblemDescription& problem,
const ExecutionContext& ctx) const override
{
// check if problem is of the kind TunaNet was trained to handle
if(!problem.Is2d())
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Problem not 2D");
return false;
}
if(problem.GetInLayout() != "NCHW")
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Layout not supported");
return false;
}
if(problem.GetKernelStrideH() != problem.GetKernelStrideW())
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Stride must be equal along all axes");
return false;
}
if(problem.GetDilationH() != problem.GetDilationW())
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Dilation must be 1");
return false;
}
if(problem.GetBias() != 0)
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Bias must be 0");
return false;
}
const auto data_type = problem.GetInDataType();
if(data_type != miopenFloat && data_type != miopenHalf && data_type != miopenBFloat16)
{
MIOPEN_LOG_I2("TunaNet Inapplicable: Unsupported data type");
return false;
}

// check if the context is s.t. no solver TunaNet may predict would be applicable
size_t applicable_solvers = 0;
for(const auto& solver_name : metadata.solver_map)
{
auto solver_id = solver::Id{solver_name.second};
auto solver = solver_id.GetSolver();
if(solver.IsApplicable(ctx, problem))
{
applicable_solvers++;
break;
}
}
if(applicable_solvers == 0)
Comment on lines +299 to +310
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably it's a bit wordy, but from a high-level point of view if show that we want to find a sovler which IsApplicable.

Suggested change
size_t applicable_solvers = 0;
for(const auto& solver_name : metadata.solver_map)
{
auto solver_id = solver::Id{solver_name.second};
auto solver = solver_id.GetSolver();
if(solver.IsApplicable(ctx, problem))
{
applicable_solvers++;
break;
}
}
if(applicable_solvers == 0)
if(std::find_if(metadata.solver_map.begin(),
metadata.solver_map.end(),
[&ctx, &problem](const auto& solver_name)
{
return solver::Id{solver_name.second}.GetSolver().IsApplicable(ctx, problem);
}) == metadata.solver_map.end())

{
MIOPEN_LOG_I2("TunaNet Inapplicable: No solver that TunaNet may predict applies");
return false;
}
MIOPEN_LOG_I2("TunaNet Applicable");
return true;
}

protected:
std::vector<float> ToFeatures(const ProblemDescription& problem) const override
{
const bool isFwd = problem.GetDirection() == conv::Direction::Forward;
std::vector<float> features = {
static_cast<float>(isFwd ? problem.GetInChannels_() : problem.GetOutChannels_()),
static_cast<float>(isFwd ? problem.GetInHeight_() : problem.GetOutHeight_()),
static_cast<float>(isFwd ? problem.GetInWidth_() : problem.GetOutWidth_()),
static_cast<float>(isFwd ? problem.GetOutChannels_() : problem.GetInChannels_()),
static_cast<float>(isFwd ? problem.GetOutHeight_() : problem.GetInHeight_()),
static_cast<float>(isFwd ? problem.GetOutWidth_() : problem.GetInWidth_()),
static_cast<float>(problem.GetWeightsHeight_()),
static_cast<float>(problem.GetWeightsWidth_()),
static_cast<float>(problem.GetPadH()),
static_cast<float>(problem.GetPadW()),
static_cast<float>(problem.GetKernelStrideH()),
static_cast<float>(problem.GetKernelStrideW()),
static_cast<float>(problem.GetDilationH()),
static_cast<float>(problem.GetDilationW()),
static_cast<float>(problem.GetOutBatchSize_()),
static_cast<float>(metadata.EncodePrecision(problem.GetInDataType())),
static_cast<float>(metadata.EncodeDirection(problem.GetDirection())),
static_cast<float>(problem.GetGroupCount())};

// normalize
for(size_t i = 0; i < features.size(); ++i)
features[i] = (features[i] - metadata.features_mean[i]) / metadata.features_std[i];

return features;
}
};

std::unique_ptr<Model> GetModel(const std::string& device)
{
if(device == "gfx90a")
return std::make_unique<Gfx90aModel>();
else if(device == "gfx908")
return std::make_unique<Gfx908Model>();
else
return nullptr;
}

std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
const ExecutionContext& ctx,
Expand All @@ -270,7 +371,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
auto db_res = db.FindRecord(static_cast<const conv::ProblemDescription&>(problem));
if(db_res)
{
MIOPEN_LOG_I2("Cached heuristic result found");
MIOPEN_LOG_I2("Cached heuristic (TunaNet) result found");
std::vector<uint64_t> db_sol(db_res->size());
// cast returned record to solver ids
std::transform(db_res->begin(), db_res->end(), db_sol.begin(), [](boost::any id) {
Expand All @@ -286,7 +387,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
return db_sol;
}

MIOPEN_LOG_I2("Evaluating Heuristic");
MIOPEN_LOG_I2("Evaluating TunaNet");

std::vector<float> res = model->Forward(problem);
std::vector<std::pair<int, float>> sort_res(res.size());
Expand Down Expand Up @@ -322,7 +423,7 @@ std::vector<uint64_t> PredictSolver(const ProblemDescription& problem,
std::stringstream ss;
for(auto& id : sol)
ss << solver::Id{id}.ToString() << " ID:" << id << ", ";
MIOPEN_LOG_I2("Heuristic Result: " << ss.str());
MIOPEN_LOG_I2("TunaNet Result: " << ss.str());
}
return sol;
}
Expand Down
2 changes: 1 addition & 1 deletion src/include/miopen/problem_description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ struct ProblemDescriptionCompatTemporary
int GetOutWidth() const { return out_width; }
// int GetOutDepth() const { return out_depth; }
int GetBatchSize() const { return batch_sz; }
// int GetBias() const { return bias; }
int GetBias() const { return bias; }
// std::string GetInLayout() const { return in_layout; }
// std::string GetOutLayout() const { return out_layout; }
miopenDataType_t GetInDataType() const { return in_data_type; }
Expand Down
1 change: 1 addition & 0 deletions src/kernels/gfx90a.tn.model

Large diffs are not rendered by default.

Loading