diff --git a/src/ocl/convolutionocl.cpp b/src/ocl/convolutionocl.cpp index 434e8cbf36..2c54e26c67 100644 --- a/src/ocl/convolutionocl.cpp +++ b/src/ocl/convolutionocl.cpp @@ -204,7 +204,8 @@ static void ShrinkToFind10Results(std::vector& found) static inline std::vector FindConvolution(const ExecutionContext& ctx, const conv::ProblemDescription& problem, - const AnyInvokeParams& invoke_ctx) + const AnyInvokeParams& invoke_ctx, + const int requestAlgoCount) { auto results = std::vector{}; auto sol = boost::optional{}; @@ -214,7 +215,16 @@ static inline std::vector FindConvolution(const ExecutionContext& ctx if(findMode.IsFast(ctx) || findMode.IsHybrid(ctx)) { auto fallback = bool{}; - auto sols = conv.GetSolutions(ctx, problem, 1, &fallback); + auto sols = conv.GetSolutions(ctx, problem, requestAlgoCount, &fallback); + + // Remove solutions for which the given workspace size is insufficient + sols.erase(std::remove_if(sols.begin(), + sols.end(), + [&](const miopenConvSolution_t& entry) { + return invoke_ctx.GetWorkspaceSize() < entry.workspace_size; + }), + sols.end()); + // override the normal find with immed mode with env var if(!sols.empty() && (!(findMode.IsHybrid(ctx) && fallback) || miopen::IsEnabled(ENV(MIOPEN_DEBUG_FORCE_IMMED_MODE_FALLBACK)))) @@ -303,7 +313,7 @@ void ConvolutionDescriptor::FindConvFwdAlgorithm(Handle& handle, workSpaceSize, attribute.gfx90aFp16alt.GetFwd()}; - const auto results = FindConvolution(ctx, problem, invoke_ctx); + const auto results = FindConvolution(ctx, problem, invoke_ctx, requestAlgoCount); if(results.empty()) { @@ -891,7 +901,7 @@ void ConvolutionDescriptor::FindConvBwdDataAlgorithm(Handle& handle, workSpaceSize, this->attribute.gfx90aFp16alt.GetBwd()}; - const auto results = FindConvolution(ctx, problem, invoke_ctx); + const auto results = FindConvolution(ctx, problem, invoke_ctx, requestAlgoCount); if(results.empty()) { @@ -1102,7 +1112,7 @@ void ConvolutionDescriptor::FindConvBwdWeightsAlgorithm(Handle& handle, workSpaceSize, attribute.gfx90aFp16alt.GetWrW()}; - const auto results = FindConvolution(ctx, problem, invoke_ctx); + const auto results = FindConvolution(ctx, problem, invoke_ctx, requestAlgoCount); if(results.empty()) {