Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Find 2.0 problem fusing #2466

Merged
merged 3 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion include/miopen/miopen.h
Original file line number Diff line number Diff line change
Expand Up @@ -5475,7 +5475,7 @@ MIOPEN_EXPORT miopenStatus_t miopenGetSolverIdConvAlgorithm(uint64_t solverId,
#ifdef MIOPEN_BETA_API

/*! @brief Initializes a problem object describing an activation operation.
* @note As of now there is no way to actually get any solution for this kind of problems
* @note As of now there is no way to actually get any solution for this kind of problems.
*
* @param problem Pointer to the problem to initialize
* @param operatorDesc Descriptor of the operator to be used
Expand All @@ -5487,6 +5487,27 @@ miopenCreateActivationProblem(miopenProblem_t* problem,
miopenActivationDescriptor_t operatorDesc,
miopenProblemDirection_t direction);

/*! @brief Fuse two problems into a single one. Problems can be either regular, or fused. No
* problems are disposed in the process, so the problem2 should be destroyed manually if it is not
* needed anymore.
* @example
* miopenProblem_t problem = makeSomeProblem1();
* miopenProblem_t problem2 = makeSomeProblem2();
* miopenProblem_t problem3 = makeSomeProblem3();
* miopenFuseProblems(problem, problem2);
* // Now problem contains {problem1, problem2}
* miopenFuseProblems(problem, problem3);
* // Now problem contains {problem1, problem2, problem3}
* miopenDestroyProblem(problem2);
* miopenDestroyProblem(problem3);
* @note As of now there is no way to actually get any solution for this kind of problems.
*
* @param problem1 The first problem to fuse. The result would be stored here.
* @param problem2 The second problem to fuse.
* @return miopenStatus_t
*/
MIOPEN_EXPORT miopenStatus_t miopenFuseProblems(miopenProblem_t problem1, miopenProblem_t problem2);

#endif

/** @} */
Expand Down
71 changes: 62 additions & 9 deletions src/api/find2_0_commons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,20 @@
#include <miopen/type_name.hpp>

#include <nlohmann/json.hpp>
#include <boost/hof/match.hpp>

template <class OperationDescriptor>
static miopenStatus_t MakeProblem(miopenProblem_t* problem,
OperationDescriptor operatorDesc,
miopenProblemDirection_t direction)
{
return miopen::try_([&] {
miopen::deref(problem) = new miopen::Problem();
decltype(auto) problem_deref = miopen::deref(*problem);
decltype(auto) operator_deref = miopen::deref(operatorDesc);
miopen::deref(problem) = new miopen::ProblemContainer();
auto& container_deref = miopen::deref(*problem);

container_deref.item = miopen::Problem();
auto& problem_deref = boost::get<miopen::Problem>(container_deref.item);
auto& operator_deref = miopen::deref(operatorDesc);

problem_deref.SetOperatorDescriptor(operator_deref);
problem_deref.SetDirection(direction);
Expand All @@ -70,6 +74,42 @@ miopenStatus_t miopenCreateActivationProblem(miopenProblem_t* problem,
return MakeProblem(problem, operatorDesc, direction);
}

miopenStatus_t miopenFuseProblems(miopenProblem_t problem1, miopenProblem_t problem2)
{
MIOPEN_LOG_FUNCTION(problem1, problem2);
return miopen::try_([&] {
auto& problem1_deref = miopen::deref(problem1);

auto emplace_problem2 = [problem2](auto& problems) {
const auto impl2 = boost::hof::match(
[&](miopen::Problem& problem2_inner) { problems.emplace_back(problem2_inner); },
[&](const miopen::FusedProblem& problem2_inner) {
problems.reserve(problems.size() + problem2_inner.problems.size());
std::copy(problem2_inner.problems.begin(),
problem2_inner.problems.end(),
std::back_inserter(problems));
});

boost::apply_visitor(impl2, miopen::deref(problem2).item);
};

boost::apply_visitor(boost::hof::match(
[&](miopen::Problem& problem1_inner) {
auto tmp = miopen::FusedProblem{};
tmp.problems.reserve(2);
tmp.problems.emplace_back(problem1_inner);
emplace_problem2(tmp.problems);
problem1_deref.item = std::move(tmp);
},
[&](miopen::FusedProblem& problem1_inner) {
emplace_problem2(problem1_inner.problems);
}),
miopen::deref(problem1).item);

boost::get<miopen::FusedProblem&>(miopen::deref(problem1).item).PropagateDescriptors();
});
}

miopenStatus_t miopenDestroyProblem(miopenProblem_t problem)
{
MIOPEN_LOG_FUNCTION(problem);
Expand All @@ -82,8 +122,18 @@ miopenStatus_t miopenSetProblemTensorDescriptor(miopenProblem_t problem,
{
MIOPEN_LOG_FUNCTION(problem, id, descriptor);

return miopen::try_(
[&] { miopen::deref(problem).RegisterTensorDescriptor(id, miopen::deref(descriptor)); });
return miopen::try_([&] {
const auto impl = boost::hof::match(
[&](miopen::Problem& problem) {
problem.RegisterTensorDescriptor(id, miopen::deref(descriptor));
},
[&](const miopen::FusedProblem&) {
MIOPEN_THROW(miopenStatusBadParm,
"Attempt to set tensor descriptor of a fused problem");
});

boost::apply_visitor(impl, miopen::deref(problem).item);
});
}

miopenStatus_t miopenCreateFindOptions(miopenFindOptions_t* options)
Expand Down Expand Up @@ -163,15 +213,18 @@ miopenStatus_t miopenFindSolutions(miopenHandle_t handle,

return miopen::try_([&] {
auto& handle_deref = miopen::deref(handle);
const auto& problem_deref = miopen::deref(problem);
const auto& problem_deref = miopen::deref(problem).item;

problem_deref.LogDriverCommand();
boost::apply_visitor([](auto&& problem) { problem.LogDriverCommand(); }, problem_deref);

const auto& options_deref =
options == nullptr ? miopen::FindOptions{} : miopen::deref(options);

auto solutions_deref =
problem_deref.FindSolutions(handle_deref, options_deref, maxSolutions);
auto solutions_deref = boost::apply_visitor(
[&](auto&& problem) {
return problem.FindSolutions(handle_deref, options_deref, maxSolutions);
},
problem_deref);

for(auto i = 0; i < solutions_deref.size(); ++i)
miopen::deref(solutions + i) = new miopen::Solution{std::move(solutions_deref[i])};
Expand Down
70 changes: 67 additions & 3 deletions src/include/miopen/problem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ struct ProblemDescription;

using OperatorDescriptor = boost::variant<ConvolutionDescriptor, ActivationDescriptor>;

struct Problem : miopenProblem
struct Problem
{
friend struct FusedProblem;

Problem() = default;

const TensorDescriptor& GetTensorDescriptor(miopenTensorArgumentId_t name) const
Expand Down Expand Up @@ -87,9 +89,33 @@ struct Problem : miopenProblem
FindSolutions(Handle& handle, const FindOptions& options, std::size_t max_solutions) const;

conv::ProblemDescription AsConvolution() const;

activ::ProblemDescription AsActivation() const;

[[nodiscard]] miopenTensorArgumentId_t GetInputId() const;
[[nodiscard]] miopenTensorArgumentId_t GetOutputId() const;

[[nodiscard]] const TensorDescriptor& GetInput() const
{
return tensor_descriptors.at(GetInputId());
}

[[nodiscard]] const TensorDescriptor& GetOutput() const
{
return tensor_descriptors.at(GetOutputId());
}

[[nodiscard]] bool HasInput() const
{
return tensor_descriptors.find(GetInputId()) != tensor_descriptors.end();
}

[[nodiscard]] bool HasOutput() const
{
return tensor_descriptors.find(GetOutputId()) != tensor_descriptors.end();
}

void CalculateOutput();

const TensorDescriptor& GetTensorDescriptorChecked(miopenTensorArgumentId_t name,
const std::string& name_str) const;

Expand Down Expand Up @@ -125,6 +151,30 @@ struct Problem : miopenProblem
void LogDriverCommand(const ActivationDescriptor& descriptor) const;
};

struct FusedProblem
{
std::vector<Problem> problems;

void LogDriverCommand() const
{
// Not implemented, but silently
}

std::vector<Solution> FindSolutions(Handle& /*handle*/,
const FindOptions& /*options*/,
std::size_t /*max_solutions*/) const
{
MIOPEN_THROW(miopenStatusNotImplemented);
}

void PropagateDescriptors();
};

struct ProblemContainer : miopenProblem
{
boost::variant<Problem, FusedProblem> item;
};

} // namespace miopen

inline std::ostream& operator<<(std::ostream& stream, const miopen::Problem& problem)
Expand All @@ -134,4 +184,18 @@ inline std::ostream& operator<<(std::ostream& stream, const miopen::Problem& pro
return stream;
}

MIOPEN_DEFINE_OBJECT(miopenProblem, miopen::Problem);
inline std::ostream& operator<<(std::ostream& stream, const miopen::FusedProblem& problem)
{
// Todo: sane printing
stream << &problem;
return stream;
}

inline std::ostream& operator<<(std::ostream& stream, const miopen::ProblemContainer& problem)
{
// Todo: sane printing
stream << &problem;
return stream;
}

MIOPEN_DEFINE_OBJECT(miopenProblem, miopen::ProblemContainer);
54 changes: 54 additions & 0 deletions src/problem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,4 +497,58 @@ void from_json(const nlohmann::json& json, Problem& problem)
primitive, &operator_json, &problem.operator_descriptor);
}

void Problem::CalculateOutput()
{
if(!HasInput())
return;

boost::apply_visitor(boost::hof::match(
[&](const ConvolutionDescriptor& conv) {
const auto& in = GetInput();
conv.GetForwardOutputTensor(
in,
GetTensorDescriptorChecked(miopenTensorConvolutionW,
"miopenTensorConvolutionW"),
in.GetType());
},
[&](const ActivationDescriptor&) {
RegisterTensorDescriptor(GetOutputId(), GetInput());
}),
operator_descriptor);
}

miopenTensorArgumentId_t Problem::GetInputId() const
{
return boost::apply_visitor(
boost::hof::match([](const ConvolutionDescriptor&) { return miopenTensorConvolutionX; },
[](const ActivationDescriptor&) { return miopenTensorActivationX; }),
operator_descriptor);
}

miopenTensorArgumentId_t Problem::GetOutputId() const
{
return boost::apply_visitor(
boost::hof::match([](const ConvolutionDescriptor&) { return miopenTensorConvolutionY; },
[](const ActivationDescriptor&) { return miopenTensorActivationY; }),
operator_descriptor);
}

void FusedProblem::PropagateDescriptors()
{
for(auto i = 0; i < problems.size(); ++i)
{
auto& cur = problems[i];

if(i > 0 && !cur.HasInput())
{
auto& prev = problems[i - 1];
if(prev.HasOutput())
cur.RegisterTensorDescriptor(cur.GetInputId(), prev.GetOutput());
}

if(cur.HasInput() && !cur.HasOutput())
cur.CalculateOutput();
}
}

} // namespace miopen