Skip to content

Commit

Permalink
Find 2.0 problem fusing (#2466)
Browse files Browse the repository at this point in the history
  • Loading branch information
DrizztDoUrden authored Nov 10, 2023
1 parent 4b0b8b2 commit 4437720
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 13 deletions.
23 changes: 22 additions & 1 deletion include/miopen/miopen.h
Original file line number Diff line number Diff line change
Expand Up @@ -5475,7 +5475,7 @@ MIOPEN_EXPORT miopenStatus_t miopenGetSolverIdConvAlgorithm(uint64_t solverId,
#ifdef MIOPEN_BETA_API

/*! @brief Initializes a problem object describing an activation operation.
* @note As of now there is no way to actually get any solution for this kind of problems
* @note As of now there is no way to actually get any solution for this kind of problems.
*
* @param problem Pointer to the problem to initialize
* @param operatorDesc Descriptor of the operator to be used
Expand All @@ -5487,6 +5487,27 @@ miopenCreateActivationProblem(miopenProblem_t* problem,
miopenActivationDescriptor_t operatorDesc,
miopenProblemDirection_t direction);

/*! @brief Fuse two problems into a single one. Problems can be either regular, or fused. No
* problems are disposed in the process, so the problem2 should be destroyed manually if it is not
* needed anymore.
* @example
* miopenProblem_t problem = makeSomeProblem1();
* miopenProblem_t problem2 = makeSomeProblem2();
* miopenProblem_t problem3 = makeSomeProblem3();
* miopenFuseProblems(problem, problem2);
* // Now problem contains {problem1, problem2}
* miopenFuseProblems(problem, problem3);
* // Now problem contains {problem1, problem2, problem3}
* miopenDestroyProblem(problem2);
* miopenDestroyProblem(problem3);
* @note As of now there is no way to actually get any solution for this kind of problems.
*
* @param problem1 The first problem to fuse. The result would be stored here.
* @param problem2 The second problem to fuse.
* @return miopenStatus_t
*/
MIOPEN_EXPORT miopenStatus_t miopenFuseProblems(miopenProblem_t problem1, miopenProblem_t problem2);

#endif

/** @} */
Expand Down
71 changes: 62 additions & 9 deletions src/api/find2_0_commons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,20 @@
#include <miopen/type_name.hpp>

#include <nlohmann/json.hpp>
#include <boost/hof/match.hpp>

template <class OperationDescriptor>
static miopenStatus_t MakeProblem(miopenProblem_t* problem,
OperationDescriptor operatorDesc,
miopenProblemDirection_t direction)
{
return miopen::try_([&] {
miopen::deref(problem) = new miopen::Problem();
decltype(auto) problem_deref = miopen::deref(*problem);
decltype(auto) operator_deref = miopen::deref(operatorDesc);
miopen::deref(problem) = new miopen::ProblemContainer();
auto& container_deref = miopen::deref(*problem);

container_deref.item = miopen::Problem();
auto& problem_deref = boost::get<miopen::Problem>(container_deref.item);
auto& operator_deref = miopen::deref(operatorDesc);

problem_deref.SetOperatorDescriptor(operator_deref);
problem_deref.SetDirection(direction);
Expand All @@ -70,6 +74,42 @@ miopenStatus_t miopenCreateActivationProblem(miopenProblem_t* problem,
return MakeProblem(problem, operatorDesc, direction);
}

miopenStatus_t miopenFuseProblems(miopenProblem_t problem1, miopenProblem_t problem2)
{
MIOPEN_LOG_FUNCTION(problem1, problem2);
return miopen::try_([&] {
auto& problem1_deref = miopen::deref(problem1);

auto emplace_problem2 = [problem2](auto& problems) {
const auto impl2 = boost::hof::match(
[&](miopen::Problem& problem2_inner) { problems.emplace_back(problem2_inner); },
[&](const miopen::FusedProblem& problem2_inner) {
problems.reserve(problems.size() + problem2_inner.problems.size());
std::copy(problem2_inner.problems.begin(),
problem2_inner.problems.end(),
std::back_inserter(problems));
});

boost::apply_visitor(impl2, miopen::deref(problem2).item);
};

boost::apply_visitor(boost::hof::match(
[&](miopen::Problem& problem1_inner) {
auto tmp = miopen::FusedProblem{};
tmp.problems.reserve(2);
tmp.problems.emplace_back(problem1_inner);
emplace_problem2(tmp.problems);
problem1_deref.item = std::move(tmp);
},
[&](miopen::FusedProblem& problem1_inner) {
emplace_problem2(problem1_inner.problems);
}),
miopen::deref(problem1).item);

boost::get<miopen::FusedProblem&>(miopen::deref(problem1).item).PropagateDescriptors();
});
}

miopenStatus_t miopenDestroyProblem(miopenProblem_t problem)
{
MIOPEN_LOG_FUNCTION(problem);
Expand All @@ -82,8 +122,18 @@ miopenStatus_t miopenSetProblemTensorDescriptor(miopenProblem_t problem,
{
MIOPEN_LOG_FUNCTION(problem, id, descriptor);

return miopen::try_(
[&] { miopen::deref(problem).RegisterTensorDescriptor(id, miopen::deref(descriptor)); });
return miopen::try_([&] {
const auto impl = boost::hof::match(
[&](miopen::Problem& problem) {
problem.RegisterTensorDescriptor(id, miopen::deref(descriptor));
},
[&](const miopen::FusedProblem&) {
MIOPEN_THROW(miopenStatusBadParm,
"Attempt to set tensor descriptor of a fused problem");
});

boost::apply_visitor(impl, miopen::deref(problem).item);
});
}

miopenStatus_t miopenCreateFindOptions(miopenFindOptions_t* options)
Expand Down Expand Up @@ -163,15 +213,18 @@ miopenStatus_t miopenFindSolutions(miopenHandle_t handle,

return miopen::try_([&] {
auto& handle_deref = miopen::deref(handle);
const auto& problem_deref = miopen::deref(problem);
const auto& problem_deref = miopen::deref(problem).item;

problem_deref.LogDriverCommand();
boost::apply_visitor([](auto&& problem) { problem.LogDriverCommand(); }, problem_deref);

const auto& options_deref =
options == nullptr ? miopen::FindOptions{} : miopen::deref(options);

auto solutions_deref =
problem_deref.FindSolutions(handle_deref, options_deref, maxSolutions);
auto solutions_deref = boost::apply_visitor(
[&](auto&& problem) {
return problem.FindSolutions(handle_deref, options_deref, maxSolutions);
},
problem_deref);

for(auto i = 0; i < solutions_deref.size(); ++i)
miopen::deref(solutions + i) = new miopen::Solution{std::move(solutions_deref[i])};
Expand Down
70 changes: 67 additions & 3 deletions src/include/miopen/problem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ struct ProblemDescription;

using OperatorDescriptor = boost::variant<ConvolutionDescriptor, ActivationDescriptor>;

struct Problem : miopenProblem
struct Problem
{
friend struct FusedProblem;

Problem() = default;

const TensorDescriptor& GetTensorDescriptor(miopenTensorArgumentId_t name) const
Expand Down Expand Up @@ -87,9 +89,33 @@ struct Problem : miopenProblem
FindSolutions(Handle& handle, const FindOptions& options, std::size_t max_solutions) const;

conv::ProblemDescription AsConvolution() const;

activ::ProblemDescription AsActivation() const;

[[nodiscard]] miopenTensorArgumentId_t GetInputId() const;
[[nodiscard]] miopenTensorArgumentId_t GetOutputId() const;

[[nodiscard]] const TensorDescriptor& GetInput() const
{
return tensor_descriptors.at(GetInputId());
}

[[nodiscard]] const TensorDescriptor& GetOutput() const
{
return tensor_descriptors.at(GetOutputId());
}

[[nodiscard]] bool HasInput() const
{
return tensor_descriptors.find(GetInputId()) != tensor_descriptors.end();
}

[[nodiscard]] bool HasOutput() const
{
return tensor_descriptors.find(GetOutputId()) != tensor_descriptors.end();
}

void CalculateOutput();

const TensorDescriptor& GetTensorDescriptorChecked(miopenTensorArgumentId_t name,
const std::string& name_str) const;

Expand Down Expand Up @@ -125,6 +151,30 @@ struct Problem : miopenProblem
void LogDriverCommand(const ActivationDescriptor& descriptor) const;
};

struct FusedProblem
{
std::vector<Problem> problems;

void LogDriverCommand() const
{
// Not implemented, but silently
}

std::vector<Solution> FindSolutions(Handle& /*handle*/,
const FindOptions& /*options*/,
std::size_t /*max_solutions*/) const
{
MIOPEN_THROW(miopenStatusNotImplemented);
}

void PropagateDescriptors();
};

struct ProblemContainer : miopenProblem
{
boost::variant<Problem, FusedProblem> item;
};

} // namespace miopen

inline std::ostream& operator<<(std::ostream& stream, const miopen::Problem& problem)
Expand All @@ -134,4 +184,18 @@ inline std::ostream& operator<<(std::ostream& stream, const miopen::Problem& pro
return stream;
}

MIOPEN_DEFINE_OBJECT(miopenProblem, miopen::Problem);
inline std::ostream& operator<<(std::ostream& stream, const miopen::FusedProblem& problem)
{
// Todo: sane printing
stream << &problem;
return stream;
}

inline std::ostream& operator<<(std::ostream& stream, const miopen::ProblemContainer& problem)
{
// Todo: sane printing
stream << &problem;
return stream;
}

MIOPEN_DEFINE_OBJECT(miopenProblem, miopen::ProblemContainer);
54 changes: 54 additions & 0 deletions src/problem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,4 +496,58 @@ void from_json(const nlohmann::json& json, Problem& problem)
primitive, &operator_json, &problem.operator_descriptor);
}

void Problem::CalculateOutput()
{
if(!HasInput())
return;

boost::apply_visitor(boost::hof::match(
[&](const ConvolutionDescriptor& conv) {
const auto& in = GetInput();
conv.GetForwardOutputTensor(
in,
GetTensorDescriptorChecked(miopenTensorConvolutionW,
"miopenTensorConvolutionW"),
in.GetType());
},
[&](const ActivationDescriptor&) {
RegisterTensorDescriptor(GetOutputId(), GetInput());
}),
operator_descriptor);
}

miopenTensorArgumentId_t Problem::GetInputId() const
{
return boost::apply_visitor(
boost::hof::match([](const ConvolutionDescriptor&) { return miopenTensorConvolutionX; },
[](const ActivationDescriptor&) { return miopenTensorActivationX; }),
operator_descriptor);
}

miopenTensorArgumentId_t Problem::GetOutputId() const
{
return boost::apply_visitor(
boost::hof::match([](const ConvolutionDescriptor&) { return miopenTensorConvolutionY; },
[](const ActivationDescriptor&) { return miopenTensorActivationY; }),
operator_descriptor);
}

void FusedProblem::PropagateDescriptors()
{
for(auto i = 0; i < problems.size(); ++i)
{
auto& cur = problems[i];

if(i > 0 && !cur.HasInput())
{
auto& prev = problems[i - 1];
if(prev.HasOutput())
cur.RegisterTensorDescriptor(cur.GetInputId(), prev.GetOutput());
}

if(cur.HasInput() && !cur.HasOutput())
cur.CalculateOutput();
}
}

} // namespace miopen

0 comments on commit 4437720

Please sign in to comment.