Skip to content

Commit

Permalink
Rename function and return invalid solution instead of throwing an er…
Browse files Browse the repository at this point in the history
…ror (ROCm#2457)
  • Loading branch information
CAHEK7 authored Oct 18, 2023
1 parent 23ac5ab commit 09fca66
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 19 deletions.
7 changes: 5 additions & 2 deletions src/include/miopen/solver/implicitgemm_ck_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,16 @@ template <typename DeviceOpType,
typename CKArgsType,
typename CastType,
typename ProblemDescriptionType = ProblemDescription>
ConvSolution InitInvokerFactory(const ProblemDescriptionType& problem, const std::string& kernel_id)
ConvSolution MakeInvokerFactory(const ProblemDescriptionType& problem, const std::string& kernel_id)
{
auto conv_ptrs = DeviceOpType::GetInstances();
auto ptr_iter = FindConvPtrByID(conv_ptrs, kernel_id);

if(ptr_iter == conv_ptrs.end())
MIOPEN_THROW("PerformanceConfig kernel '" + kernel_id + "' does not exist");
{
MIOPEN_LOG_E("PerformanceConfig kernel '" + kernel_id + "' does not exist.");
return {miopenStatusInvalidValue};
}

ConvSolution result;
result.invoker_factory =
Expand Down
6 changes: 3 additions & 3 deletions src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,13 +349,13 @@ ConvSolution ConvHipImplicitGemm3DGroupBwdXdlops::GetSolution(
switch(problem.GetInDataType())
{
case miopenInt8:
return InitInvokerFactory<DeviceOpGBwdPtrs<int8_t>, CKArgs, conv::DataInvokeParams>(
return MakeInvokerFactory<DeviceOpGBwdPtrs<int8_t>, CKArgs, conv::DataInvokeParams>(
problem, config.kernel_id);
case miopenHalf:
return InitInvokerFactory<DeviceOpGBwdPtrs<ck::half_t>, CKArgs, conv::DataInvokeParams>(
return MakeInvokerFactory<DeviceOpGBwdPtrs<ck::half_t>, CKArgs, conv::DataInvokeParams>(
problem, config.kernel_id);
case miopenFloat:
return InitInvokerFactory<DeviceOpGBwdPtrs<float>, CKArgs, conv::DataInvokeParams>(
return MakeInvokerFactory<DeviceOpGBwdPtrs<float>, CKArgs, conv::DataInvokeParams>(
problem, config.kernel_id);
case miopenInt32:
case miopenInt8x4: // Support discontinued.
Expand Down
6 changes: 3 additions & 3 deletions src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,13 +347,13 @@ ConvSolution ConvHipImplicitGemm3DGroupFwdXdlops::GetSolution(
switch(problem.GetInDataType())
{
case miopenInt8:
return InitInvokerFactory<DeviceOpGFwdPtrs<int8_t>, CKArgs, conv::DataInvokeParams>(
return MakeInvokerFactory<DeviceOpGFwdPtrs<int8_t>, CKArgs, conv::DataInvokeParams>(
problem, config.kernel_id);
case miopenHalf:
return InitInvokerFactory<DeviceOpGFwdPtrs<ck::half_t>, CKArgs, conv::DataInvokeParams>(
return MakeInvokerFactory<DeviceOpGFwdPtrs<ck::half_t>, CKArgs, conv::DataInvokeParams>(
problem, config.kernel_id);
case miopenFloat:
return InitInvokerFactory<DeviceOpGFwdPtrs<float>, CKArgs, conv::DataInvokeParams>(
return MakeInvokerFactory<DeviceOpGFwdPtrs<float>, CKArgs, conv::DataInvokeParams>(
problem, config.kernel_id);
case miopenInt32:
case miopenInt8x4: // Support discontinued.
Expand Down
6 changes: 3 additions & 3 deletions src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,13 +343,13 @@ ConvSolution ConvHipImplicitGemm3DGroupWrwXdlops::GetSolution(
switch(problem.GetInDataType())
{
case miopenInt8:
return InitInvokerFactory<DeviceOpGWrwPtrs<int8_t>, CKArgs, conv::WrWInvokeParams>(
return MakeInvokerFactory<DeviceOpGWrwPtrs<int8_t>, CKArgs, conv::WrWInvokeParams>(
problem, config.kernel_id);
case miopenHalf:
return InitInvokerFactory<DeviceOpGWrwPtrs<ck::half_t>, CKArgs, conv::WrWInvokeParams>(
return MakeInvokerFactory<DeviceOpGWrwPtrs<ck::half_t>, CKArgs, conv::WrWInvokeParams>(
problem, config.kernel_id);
case miopenFloat:
return InitInvokerFactory<DeviceOpGWrwPtrs<float>, CKArgs, conv::WrWInvokeParams>(
return MakeInvokerFactory<DeviceOpGWrwPtrs<float>, CKArgs, conv::WrWInvokeParams>(
problem, config.kernel_id);
case miopenInt32:
case miopenInt8x4: // Support discontinued.
Expand Down
4 changes: 2 additions & 2 deletions src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,10 @@ ConvSolution ConvHipImplicitGemmBwdXdlops::GetSolution(
switch(problem.GetInDataType())
{
case miopenHalf:
return InitInvokerFactory<DeviceOpBwdPtrs<ck::half_t>, CKArgs, conv::DataInvokeParams>(
return MakeInvokerFactory<DeviceOpBwdPtrs<ck::half_t>, CKArgs, conv::DataInvokeParams>(
problem, config.kernel_id);
case miopenFloat:
return InitInvokerFactory<DeviceOpBwdPtrs<float>, CKArgs, conv::DataInvokeParams>(
return MakeInvokerFactory<DeviceOpBwdPtrs<float>, CKArgs, conv::DataInvokeParams>(
problem, config.kernel_id);
case miopenInt8:
case miopenInt32:
Expand Down
6 changes: 3 additions & 3 deletions src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,13 @@ ConvSolution ConvHipImplicitGemmFwdXdlops::GetSolution(
switch(problem.GetInDataType())
{
case miopenInt8:
return InitInvokerFactory<DeviceOpPtrs<int8_t>, CKArgs, conv::DataInvokeParams>(
return MakeInvokerFactory<DeviceOpPtrs<int8_t>, CKArgs, conv::DataInvokeParams>(
problem, config.kernel_id);
case miopenHalf:
return InitInvokerFactory<DeviceOpPtrs<ck::half_t>, CKArgs, conv::DataInvokeParams>(
return MakeInvokerFactory<DeviceOpPtrs<ck::half_t>, CKArgs, conv::DataInvokeParams>(
problem, config.kernel_id);
case miopenFloat:
return InitInvokerFactory<DeviceOpPtrs<float>, CKArgs, conv::DataInvokeParams>(
return MakeInvokerFactory<DeviceOpPtrs<float>, CKArgs, conv::DataInvokeParams>(
problem, config.kernel_id);
case miopenInt32:
case miopenInt8x4: // Support discontinued.
Expand Down
6 changes: 3 additions & 3 deletions src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,13 +330,13 @@ ConvSolution ConvHipImplicitGemmGroupFwdXdlops::GetSolution(
switch(problem.GetInDataType())
{
case miopenHalf:
return InitInvokerFactory<DeviceOpGFwdPtrs<ck::half_t>, CKArgs, conv::DataInvokeParams>(
return MakeInvokerFactory<DeviceOpGFwdPtrs<ck::half_t>, CKArgs, conv::DataInvokeParams>(
problem, config.kernel_id);
case miopenFloat:
return InitInvokerFactory<DeviceOpGFwdPtrs<float>, CKArgs, conv::DataInvokeParams>(
return MakeInvokerFactory<DeviceOpGFwdPtrs<float>, CKArgs, conv::DataInvokeParams>(
problem, config.kernel_id);
case miopenInt8:
return InitInvokerFactory<DeviceOpGFwdPtrs<int8_t>, CKArgs, conv::DataInvokeParams>(
return MakeInvokerFactory<DeviceOpGFwdPtrs<int8_t>, CKArgs, conv::DataInvokeParams>(
problem, config.kernel_id);
case miopenInt32:
case miopenInt8x4: // Support discontinued.
Expand Down

0 comments on commit 09fca66

Please sign in to comment.