Skip to content

Commit

Permalink
Merge branch 'develop' into sl/testpackage_nogpu
Browse files Browse the repository at this point in the history
  • Loading branch information
xinlipn authored Apr 25, 2024
2 parents 2db8ccd + 2232184 commit c0e672c
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/ocl/convolutionocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ static void ShrinkToFind10Results(std::vector<PerfField>& found)

static inline std::vector<PerfField> FindConvolution(const ExecutionContext& ctx,
const conv::ProblemDescription& problem,
-const AnyInvokeParams& invoke_ctx)
+const AnyInvokeParams& invoke_ctx,
+const int requestAlgoCount)
{
auto results = std::vector<PerfField>{};
auto sol = boost::optional<miopenConvSolution_t>{};
Expand All @@ -214,7 +215,16 @@ static inline std::vector<PerfField> FindConvolution(const ExecutionContext& ctx
if(findMode.IsFast(ctx) || findMode.IsHybrid(ctx))
{
auto fallback = bool{};
-auto sols = conv.GetSolutions(ctx, problem, 1, &fallback);
+auto sols = conv.GetSolutions(ctx, problem, requestAlgoCount, &fallback);
+
+// Remove solutions for which the given workspace size is insufficient
+sols.erase(std::remove_if(sols.begin(),
+sols.end(),
+[&](const miopenConvSolution_t& entry) {
+return invoke_ctx.GetWorkspaceSize() < entry.workspace_size;
+}),
+sols.end());
+
// override the normal find with immed mode with env var
if(!sols.empty() && (!(findMode.IsHybrid(ctx) && fallback) ||
miopen::IsEnabled(ENV(MIOPEN_DEBUG_FORCE_IMMED_MODE_FALLBACK))))
Expand Down Expand Up @@ -303,7 +313,7 @@ void ConvolutionDescriptor::FindConvFwdAlgorithm(Handle& handle,
workSpaceSize,
attribute.gfx90aFp16alt.GetFwd()};

-const auto results = FindConvolution(ctx, problem, invoke_ctx);
+const auto results = FindConvolution(ctx, problem, invoke_ctx, requestAlgoCount);

if(results.empty())
{
Expand Down Expand Up @@ -891,7 +901,7 @@ void ConvolutionDescriptor::FindConvBwdDataAlgorithm(Handle& handle,
workSpaceSize,
this->attribute.gfx90aFp16alt.GetBwd()};

-const auto results = FindConvolution(ctx, problem, invoke_ctx);
+const auto results = FindConvolution(ctx, problem, invoke_ctx, requestAlgoCount);

if(results.empty())
{
Expand Down Expand Up @@ -1102,7 +1112,7 @@ void ConvolutionDescriptor::FindConvBwdWeightsAlgorithm(Handle& handle,
workSpaceSize,
attribute.gfx90aFp16alt.GetWrW()};

-const auto results = FindConvolution(ctx, problem, invoke_ctx);
+const auto results = FindConvolution(ctx, problem, invoke_ctx, requestAlgoCount);

if(results.empty())
{
Expand Down

0 comments on commit c0e672c

Please sign in to comment.