Skip to content

Commit

Permalink
Merge CK fwd mha FP16 solver (#3308)
Browse files Browse the repository at this point in the history
  • Loading branch information
BrianHarrisonAMD authored Oct 17, 2024
1 parent e006bc4 commit eecfb26
Show file tree
Hide file tree
Showing 15 changed files with 694 additions and 137 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:HIP_COMPILER_FLAGS=${HIP_COMPI
# HIP
if( MIOPEN_BACKEND STREQUAL "HIP" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN_BACKEND STREQUAL "HIPNOGPU")
if(MIOPEN_USE_COMPOSABLEKERNEL)
find_package(composable_kernel 1.0.0 COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations)
find_package(composable_kernel 1.0.0 COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations device_mha_operations)
endif()
if( MIOPEN_BACKEND STREQUAL "HIPNOGPU")
set(MIOPEN_MODE_NOGPU 1)
Expand Down
3 changes: 2 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ set( MIOpen_Source
solver/layernorm/forward_layernorm2d_ck.cpp
solver/layernorm/forward_layernorm4d_ck.cpp
solver/layernorm/forward_t5layernorm.cpp
solver/mha/mha_ck_fa_v2_solver_forward.cpp
solver/mha/mha_solver_backward.cpp
solver/mha/mha_solver_forward.cpp
solver/multimarginloss/forward_multimarginloss.cpp
Expand Down Expand Up @@ -845,7 +846,7 @@ target_include_directories(MIOpen PUBLIC
)

if(MIOPEN_USE_COMPOSABLEKERNEL)
set(MIOPEN_CK_LINK_FLAGS composable_kernel::device_other_operations composable_kernel::device_gemm_operations composable_kernel::device_conv_operations composable_kernel::device_reduction_operations hip::host)
set(MIOPEN_CK_LINK_FLAGS composable_kernel::device_other_operations composable_kernel::device_gemm_operations composable_kernel::device_conv_operations composable_kernel::device_reduction_operations composable_kernel::device_mha_operations hip::host)
endif()

if(WIN32)
Expand Down
22 changes: 22 additions & 0 deletions src/include/miopen/mha/solvers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,28 @@ struct MhaBackward final : MhaSolver
MIOPEN_INTERNALS_EXPORT bool MayNeedWorkspace() const override;
};

// MHA solver that, going by its name, covers the forward pass via the Composable Kernel (CK)
// FlashAttention v2 path. Only SolverDbId() is defined inline; the remaining overrides are
// declared here and defined out of line (presumably in
// solver/mha/mha_ck_fa_v2_solver_forward.cpp -- TODO confirm).
struct MhaCKFlashAttentionV2Forward final : MhaSolver
{
// Returns the identifier string that GetSolverDbId<> yields for this solver type.
const std::string& SolverDbId() const override
{
return GetSolverDbId<MhaCKFlashAttentionV2Forward>();
}

// Applicability check for the given context/problem pair.
// NOTE(review): the actual criteria live in the out-of-line definition and are not visible here.
MIOPEN_INTERNALS_EXPORT bool
IsApplicable(const ExecutionContext& context,
const miopen::mha::ProblemDescription& problem) const override;

// Produces the ConvSolution for the given context/problem pair.
MIOPEN_INTERNALS_EXPORT ConvSolution
GetSolution(const ExecutionContext& context,
const miopen::mha::ProblemDescription& problem) const override;

// Returns a workspace size as std::size_t for the given context/problem pair
// (presumably in bytes -- confirm against the definition).
MIOPEN_INTERNALS_EXPORT std::size_t
GetWorkspaceSize(const ExecutionContext& context,
const miopen::mha::ProblemDescription& problem) const override;

// Tells callers whether this solver may require a workspace; the returned value is set by
// the out-of-line definition.
MIOPEN_INTERNALS_EXPORT bool MayNeedWorkspace() const override;
};

} // namespace mha

} // namespace solver
Expand Down
4 changes: 3 additions & 1 deletion src/problem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,12 @@ Problem::FindSolutionsImpl(Handle& handle,

const auto algo = AlgorithmName{"Mha"};

static solver::mha::MhaCKFlashAttentionV2Forward mhaCKFAForwardSolver;
static solver::mha::MhaForward mhaForwardSolver;
static solver::mha::MhaBackward mhaBackwardSolver;

std::vector<solver::mha::MhaSolver*> solvers = {&mhaForwardSolver, &mhaBackwardSolver};
std::vector<solver::mha::MhaSolver*> solvers = {
&mhaCKFAForwardSolver, &mhaForwardSolver, &mhaBackwardSolver};

for(auto solver : solvers)
{
Expand Down
34 changes: 24 additions & 10 deletions src/solution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,15 +400,32 @@ void Solution::RunImpl(Handle& handle,
return;
}

solver::mha::MhaForward mhaForward;
solver::mha::MhaBackward mhaBackward;
// Picks the MHA solver whose SolverDbId() equals the id reported by GetSolver() and returns
// that solver's GetSolution(ctx, problem_description). The CK flash-attention forward solver is
// checked first, then the generic forward solver, then the backward solver. If none of them
// matches, MIOPEN_THROW is raised with the unmatched id in the message.
auto getSolution = [&](const ExecutionContext& ctx) {
    auto requestedId = GetSolver();

    solver::mha::MhaForward genericForward;
    solver::mha::MhaBackward genericBackward;
    solver::mha::MhaCKFlashAttentionV2Forward ckFlashForward;

    // Guard clauses: return as soon as the matching solver is found.
    if(requestedId == ckFlashForward.SolverDbId())
        return ckFlashForward.GetSolution(ctx, problem_description);

    if(requestedId == genericForward.SolverDbId())
        return genericForward.GetSolution(ctx, problem_description);

    if(requestedId == genericBackward.SolverDbId())
        return genericBackward.GetSolution(ctx, problem_description);

    // No known MHA solver carries this id.
    MIOPEN_THROW("No MHA solver with matching SolverDbId of " + requestedId.ToString());
};

if(!kernels.empty())
{
const auto ctx = ExecutionContext{&handle};
const auto mha_solution = GetSolver() == mhaForward.SolverDbId()
? mhaForward.GetSolution(ctx, problem_description)
: mhaBackward.GetSolution(ctx, problem_description);
const auto mha_solution = getSolution(ctx);
auto kernel_handles = std::vector<Kernel>{std::begin(kernels), std::end(kernels)};

invoker = (*mha_solution.invoker_factory)(kernel_handles);
Expand All @@ -425,11 +442,8 @@ void Solution::RunImpl(Handle& handle,
return;
}

auto ctx = ExecutionContext{&handle};

const auto mha_solution = GetSolver() == mhaForward.SolverDbId()
? mhaForward.GetSolution(ctx, problem_description)
: mhaBackward.GetSolution(ctx, problem_description);
auto ctx = ExecutionContext{&handle};
const auto mha_solution = getSolution(ctx);

invoker =
handle.PrepareInvoker(*mha_solution.invoker_factory, mha_solution.construction_params);
Expand Down
4 changes: 3 additions & 1 deletion src/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,9 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry)
Primitive::MultiMarginLoss,
multimarginloss::MultiMarginLossForward{}.SolverDbId());

// IMPORTANT: New solvers should be added to the end of the function!
Register(registry, ++id, Primitive::Mha, mha::MhaCKFlashAttentionV2Forward{}.SolverDbId());
// IMPORTANT: New solvers should be added to the end of the function, and don't leave a blank
// line between this comment and the newly registered solver(s)!
}

bool ThisSolverIsDeprecatedStatic::IsDisabled(const ExecutionContext& ctx)
Expand Down
Loading

0 comments on commit eecfb26

Please sign in to comment.