From 8091d1ce22bd19cdba035b80acd854349c679434 Mon Sep 17 00:00:00 2001 From: Daming Feng Date: Wed, 11 Oct 2023 20:36:04 +0000 Subject: [PATCH] address comments --- src/include/miopen/solver/ck_utility_common.hpp | 10 ++++++++++ .../conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp | 4 ++-- .../conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp | 4 ++-- .../conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp | 6 ++---- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/include/miopen/solver/ck_utility_common.hpp b/src/include/miopen/solver/ck_utility_common.hpp index f4f91fa228..003b067e50 100644 --- a/src/include/miopen/solver/ck_utility_common.hpp +++ b/src/include/miopen/solver/ck_utility_common.hpp @@ -63,6 +63,16 @@ static inline bool is_ck_supported_hardware(const Handle& handle) StartsWith(handle.GetDeviceName(), "gfx1102"); } +static inline bool is_conv_ck_supported_hardware(const std::string& device_name, bool is_wrw) +{ + auto res_wrw = StartsWith(device_name, "gfx908") || StartsWith(device_name, "gfx90a") || + StartsWith(device_name, "gfx940") || StartsWith(device_name, "gfx941") || + StartsWith(device_name, "gfx942"); + return is_wrw ? res_wrw + : (res_wrw || StartsWith(device_name, "gfx900") || + StartsWith(device_name, "gfx906")); +} + static inline bool is_support_amd_buffer_atomic_fadd(const std::string& device_name) { return StartsWith(device_name, "gfx908"); diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp index e576fdb0e5..58efe498ff 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp @@ -32,6 +32,7 @@ #include #include #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL +#include #include #endif #include @@ -321,8 +322,7 @@ bool ConvHipImplicitGemm3DGroupBwdXdlops::IsApplicable( return false; if(!problem.IsLayoutNHWC()) return false; - const std::string& arch = ctx.GetStream().GetDeviceName(); - if(miopen::StartsWith(arch, "gfx11") || miopen::StartsWith(arch, "gfx10")) + if(!ck_utility::is_conv_ck_supported_hardware(ctx.GetStream().GetDeviceName(), false)) return false; switch(problem.GetInDataType()) { diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp index e772b1ea0d..e7a44456b9 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp @@ -32,6 +32,7 @@ #include #include #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL +#include #include #endif #include @@ -319,8 +320,7 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable( return false; if(!problem.IsLayoutNHWC()) return false; - const std::string& arch = ctx.GetStream().GetDeviceName(); - if(miopen::StartsWith(arch, "gfx11") || miopen::StartsWith(arch, "gfx10")) + if(!ck_utility::is_conv_ck_supported_hardware(ctx.GetStream().GetDeviceName(), false)) return false; switch(problem.GetInDataType()) { diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp index 0d24c8c602..d0236e4f42 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp @@ -32,6 +32,7 @@ #include #include #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL +#include #include #endif #include @@ -315,10 +316,7 @@ bool ConvHipImplicitGemm3DGroupWrwXdlops::IsApplicable( return false; if(!problem.IsLayoutNHWC()) return false; - const std::string& arch = ctx.GetStream().GetDeviceName(); - if(miopen::StartsWith(arch, "gfx11") || miopen::StartsWith(arch, "gfx10")) - return false; - if(arch == "gfx906" || arch == "gfx900") + if(!ck_utility::is_conv_ck_supported_hardware(ctx.GetStream().GetDeviceName(), true)) return false; switch(problem.GetInDataType()) {