diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 83b131cbd0..9c47d62e81 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -5475,7 +5475,7 @@ MIOPEN_EXPORT miopenStatus_t miopenGetSolverIdConvAlgorithm(uint64_t solverId, #ifdef MIOPEN_BETA_API /*! @brief Initializes a problem object describing an activation operation. - * @note As of now there is no way to actually get any solution for this kind of problems + * @note As of now there is no way to actually get any solution for this kind of problems. * * @param problem Pointer to the problem to initialize * @param operatorDesc Descriptor of the operator to be used @@ -5487,6 +5487,27 @@ miopenCreateActivationProblem(miopenProblem_t* problem, miopenActivationDescriptor_t operatorDesc, miopenProblemDirection_t direction); +/*! @brief Fuse two problems into a single one. Problems can be either regular, or fused. No + * problems are disposed in the process, so the problem2 should be destroyed manually if it is not + * needed anymore. + * @example + * miopenProblem_t problem = makeSomeProblem1(); + * miopenProblem_t problem2 = makeSomeProblem2(); + * miopenProblem_t problem3 = makeSomeProblem3(); + * miopenFuseProblems(problem, problem2); + * // Now problem contains {problem1, problem2} + * miopenFuseProblems(problem, problem3); + * // Now problem contains {problem1, problem2, problem3} + * miopenDestroyProblem(problem2); + * miopenDestroyProblem(problem3); + * @note As of now there is no way to actually get any solution for this kind of problems. + * + * @param problem1 The first problem to fuse. The result would be stored here. + * @param problem2 The second problem to fuse. + * @return miopenStatus_t + */ +MIOPEN_EXPORT miopenStatus_t miopenFuseProblems(miopenProblem_t problem1, miopenProblem_t problem2); + #endif /** @} */ diff --git a/src/api/find2_0_commons.cpp b/src/api/find2_0_commons.cpp index 3d0bf4805d..3f864a930c 100644 --- a/src/api/find2_0_commons.cpp +++ b/src/api/find2_0_commons.cpp @@ -37,6 +37,7 @@ #include #include +#include template static miopenStatus_t MakeProblem(miopenProblem_t* problem, @@ -44,9 +45,12 @@ static miopenStatus_t MakeProblem(miopenProblem_t* problem, miopenProblemDirection_t direction) { return miopen::try_([&] { - miopen::deref(problem) = new miopen::Problem(); - decltype(auto) problem_deref = miopen::deref(*problem); - decltype(auto) operator_deref = miopen::deref(operatorDesc); + miopen::deref(problem) = new miopen::ProblemContainer(); + auto& container_deref = miopen::deref(*problem); + + container_deref.item = miopen::Problem(); + auto& problem_deref = boost::get(container_deref.item); + auto& operator_deref = miopen::deref(operatorDesc); problem_deref.SetOperatorDescriptor(operator_deref); problem_deref.SetDirection(direction); @@ -70,6 +74,42 @@ miopenStatus_t miopenCreateActivationProblem(miopenProblem_t* problem, return MakeProblem(problem, operatorDesc, direction); } +miopenStatus_t miopenFuseProblems(miopenProblem_t problem1, miopenProblem_t problem2) +{ + MIOPEN_LOG_FUNCTION(problem1, problem2); + return miopen::try_([&] { + auto& problem1_deref = miopen::deref(problem1); + + auto emplace_problem2 = [problem2](auto& problems) { + const auto impl2 = boost::hof::match( + [&](miopen::Problem& problem2_inner) { problems.emplace_back(problem2_inner); }, + [&](const miopen::FusedProblem& problem2_inner) { + problems.reserve(problems.size() + problem2_inner.problems.size()); + std::copy(problem2_inner.problems.begin(), + problem2_inner.problems.end(), + std::back_inserter(problems)); + }); + + boost::apply_visitor(impl2, miopen::deref(problem2).item); + }; + + boost::apply_visitor(boost::hof::match( + [&](miopen::Problem& problem1_inner) { + auto tmp = miopen::FusedProblem{}; + tmp.problems.reserve(2); + tmp.problems.emplace_back(problem1_inner); + emplace_problem2(tmp.problems); + problem1_deref.item = std::move(tmp); + }, + [&](miopen::FusedProblem& problem1_inner) { + emplace_problem2(problem1_inner.problems); + }), + miopen::deref(problem1).item); + + boost::get(miopen::deref(problem1).item).PropagateDescriptors(); + }); +} + miopenStatus_t miopenDestroyProblem(miopenProblem_t problem) { MIOPEN_LOG_FUNCTION(problem); @@ -82,8 +122,18 @@ miopenStatus_t miopenSetProblemTensorDescriptor(miopenProblem_t problem, { MIOPEN_LOG_FUNCTION(problem, id, descriptor); - return miopen::try_( - [&] { miopen::deref(problem).RegisterTensorDescriptor(id, miopen::deref(descriptor)); }); + return miopen::try_([&] { + const auto impl = boost::hof::match( + [&](miopen::Problem& problem) { + problem.RegisterTensorDescriptor(id, miopen::deref(descriptor)); + }, + [&](const miopen::FusedProblem&) { + MIOPEN_THROW(miopenStatusBadParm, + "Attempt to set tensor descriptor of a fused problem"); + }); + + boost::apply_visitor(impl, miopen::deref(problem).item); + }); } miopenStatus_t miopenCreateFindOptions(miopenFindOptions_t* options) @@ -163,15 +213,18 @@ miopenStatus_t miopenFindSolutions(miopenHandle_t handle, return miopen::try_([&] { auto& handle_deref = miopen::deref(handle); - const auto& problem_deref = miopen::deref(problem); + const auto& problem_deref = miopen::deref(problem).item; - problem_deref.LogDriverCommand(); + boost::apply_visitor([](auto&& problem) { problem.LogDriverCommand(); }, problem_deref); const auto& options_deref = options == nullptr ? miopen::FindOptions{} : miopen::deref(options); - auto solutions_deref = - problem_deref.FindSolutions(handle_deref, options_deref, maxSolutions); + auto solutions_deref = boost::apply_visitor( + [&](auto&& problem) { + return problem.FindSolutions(handle_deref, options_deref, maxSolutions); + }, + problem_deref); for(auto i = 0; i < solutions_deref.size(); ++i) miopen::deref(solutions + i) = new miopen::Solution{std::move(solutions_deref[i])}; diff --git a/src/include/miopen/problem.hpp b/src/include/miopen/problem.hpp index fdb2ad0cec..150fc1d4d4 100644 --- a/src/include/miopen/problem.hpp +++ b/src/include/miopen/problem.hpp @@ -58,8 +58,10 @@ struct ProblemDescription; using OperatorDescriptor = boost::variant; -struct Problem : miopenProblem +struct Problem { + friend struct FusedProblem; + Problem() = default; const TensorDescriptor& GetTensorDescriptor(miopenTensorArgumentId_t name) const @@ -87,9 +89,33 @@ struct Problem : miopenProblem FindSolutions(Handle& handle, const FindOptions& options, std::size_t max_solutions) const; conv::ProblemDescription AsConvolution() const; - activ::ProblemDescription AsActivation() const; + [[nodiscard]] miopenTensorArgumentId_t GetInputId() const; + [[nodiscard]] miopenTensorArgumentId_t GetOutputId() const; + + [[nodiscard]] const TensorDescriptor& GetInput() const + { + return tensor_descriptors.at(GetInputId()); + } + + [[nodiscard]] const TensorDescriptor& GetOutput() const + { + return tensor_descriptors.at(GetOutputId()); + } + + [[nodiscard]] bool HasInput() const + { + return tensor_descriptors.find(GetInputId()) != tensor_descriptors.end(); + } + + [[nodiscard]] bool HasOutput() const + { + return tensor_descriptors.find(GetOutputId()) != tensor_descriptors.end(); + } + + void CalculateOutput(); + const TensorDescriptor& GetTensorDescriptorChecked(miopenTensorArgumentId_t name, const std::string& name_str) const; @@ -125,6 +151,30 @@ struct Problem : miopenProblem void LogDriverCommand(const ActivationDescriptor& descriptor) const; }; +struct FusedProblem +{ + std::vector problems; + + void LogDriverCommand() const + { + // Not implemented, but silently + } + + std::vector FindSolutions(Handle& /*handle*/, + const FindOptions& /*options*/, + std::size_t /*max_solutions*/) const + { + MIOPEN_THROW(miopenStatusNotImplemented); + } + + void PropagateDescriptors(); +}; + +struct ProblemContainer : miopenProblem +{ + boost::variant item; +}; + } // namespace miopen inline std::ostream& operator<<(std::ostream& stream, const miopen::Problem& problem) @@ -134,4 +184,18 @@ inline std::ostream& operator<<(std::ostream& stream, const miopen::Problem& pro return stream; } -MIOPEN_DEFINE_OBJECT(miopenProblem, miopen::Problem); +inline std::ostream& operator<<(std::ostream& stream, const miopen::FusedProblem& problem) +{ + // Todo: sane printing + stream << &problem; + return stream; +} + +inline std::ostream& operator<<(std::ostream& stream, const miopen::ProblemContainer& problem) +{ + // Todo: sane printing + stream << &problem; + return stream; +} + +MIOPEN_DEFINE_OBJECT(miopenProblem, miopen::ProblemContainer); diff --git a/src/problem.cpp b/src/problem.cpp index 99eb994921..ac7c15b4be 100644 --- a/src/problem.cpp +++ b/src/problem.cpp @@ -496,4 +496,58 @@ void from_json(const nlohmann::json& json, Problem& problem) primitive, &operator_json, &problem.operator_descriptor); } +void Problem::CalculateOutput() +{ + if(!HasInput()) + return; + + boost::apply_visitor(boost::hof::match( + [&](const ConvolutionDescriptor& conv) { + const auto& in = GetInput(); + conv.GetForwardOutputTensor( + in, + GetTensorDescriptorChecked(miopenTensorConvolutionW, + "miopenTensorConvolutionW"), + in.GetType()); + }, + [&](const ActivationDescriptor&) { + RegisterTensorDescriptor(GetOutputId(), GetInput()); + }), + operator_descriptor); +} + +miopenTensorArgumentId_t Problem::GetInputId() const +{ + return boost::apply_visitor( + boost::hof::match([](const ConvolutionDescriptor&) { return miopenTensorConvolutionX; }, + [](const ActivationDescriptor&) { return miopenTensorActivationX; }), + operator_descriptor); +} + +miopenTensorArgumentId_t Problem::GetOutputId() const +{ + return boost::apply_visitor( + boost::hof::match([](const ConvolutionDescriptor&) { return miopenTensorConvolutionY; }, + [](const ActivationDescriptor&) { return miopenTensorActivationY; }), + operator_descriptor); +} + +void FusedProblem::PropagateDescriptors() +{ + for(auto i = 0; i < problems.size(); ++i) + { + auto& cur = problems[i]; + + if(i > 0 && !cur.HasInput()) + { + auto& prev = problems[i - 1]; + if(prev.HasOutput()) + cur.RegisterTensorDescriptor(cur.GetInputId(), prev.GetOutput()); + } + + if(cur.HasInput() && !cur.HasOutput()) + cur.CalculateOutput(); + } +} + } // namespace miopen