Skip to content

Commit

Permalink
[OpenCL] Fix cl_context gen_lwss bugs (#6734)
Browse files Browse the repository at this point in the history
* fix cl_context gen_lwss bugs test=develop

* test=develop

* test=develop
  • Loading branch information
zhenlin-work authored Aug 23, 2021
1 parent 2127185 commit b60d222
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
10 changes: 7 additions & 3 deletions lite/backends/opencl/cl_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ cl::NDRange CLContext::DefaultGlobalWorkSize(const CLImage &image) {
std::set<cl::NDRange, CLContext::CompareByRange>
CLContext::GenerateLocalWorkSizes(cl::NDRange gws, size_t max_ws) {
size_t tune_type = CLRuntime::Global()->auto_tune();
auto first_lws = DefaultLocalWorkSize(gws, max_ws, 3, false);
auto first_lws = DefaultLocalWorkSize(gws, max_ws, tune_type, 3, false);
std::set<cl::NDRange, CompareByRange> lwss;
for (auto one_lws : first_lws) {
lwss.insert(one_lws);
Expand All @@ -139,7 +139,7 @@ CLContext::GenerateLocalWorkSizes(cl::NDRange gws, size_t max_ws) {
for (bool tune_reverse : tune_reverses) {
for (size_t divisor : divisors) {
std::set<cl::NDRange, CompareByRange> tmp_lws =
DefaultLocalWorkSize(gws, max_ws, divisor, tune_reverse);
DefaultLocalWorkSize(gws, max_ws, tune_type, divisor, tune_reverse);
for (cl::NDRange one_lws : tmp_lws) {
lwss.insert(one_lws);
}
Expand Down Expand Up @@ -230,6 +230,7 @@ CLContext::GenerateLocalWorkSizes(cl::NDRange gws, size_t max_ws) {
std::set<cl::NDRange, CLContext::CompareByRange>
CLContext::DefaultLocalWorkSize(const cl::NDRange &gws,
register size_t max_ws,
size_t tune_type /*=0*/,
const int &divisor /*=2*/,
const bool &reverse /*=false*/,
const size_t &user_def_max_ws /*=0*/) {
Expand Down Expand Up @@ -266,7 +267,10 @@ CLContext::DefaultLocalWorkSize(const cl::NDRange &gws,
}
ly_src = (ly_src & 0x01) ? 1 : ly_src >> 1;
} while (ly_src > 1);

if (tune_type == lite_api::CL_TUNE_NONE && lws_set.empty()) {
lws_set.insert(
(reverse ? cl::NDRange{lz, ly, lx} : cl::NDRange{lx, ly, lz}));
}
return lws_set;
}

Expand Down
1 change: 1 addition & 0 deletions lite/backends/opencl/cl_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class CLContext {
// Builds a set of candidate local work sizes (cl::NDRange values) for the
// given global work size. The returned set is ordered by CompareByRange.
//
// Parameters:
//   global_work_size            - global NDRange the candidates are derived
//                                 from.
//   max_work_size               - presumably the upper limit on the work-group
//                                 size -- TODO confirm against the definition.
//   tune_type                   - tuning mode, defaults to 0. The definition
//                                 compares it with lite_api::CL_TUNE_NONE:
//                                 in that mode, if no candidate was produced,
//                                 a single fallback range is inserted so the
//                                 result is not empty.
//   divitor                     - defaults to 2. NOTE(review): the definition
//                                 names this parameter `divisor`; `divitor`
//                                 looks like a typo -- confirm before renaming.
//   tune_reverse                - defaults to false. For the fallback range
//                                 the definition emits {lz, ly, lx} instead of
//                                 {lx, ly, lz} when this is true.
//   user_defined_max_work_size  - defaults to 0; presumably 0 means "no
//                                 user override" -- TODO confirm.
//
// NOTE(review): the `register` storage-class specifier is deprecated since
// C++11 and removed in C++17; it may need to be dropped for newer toolchains.
std::set<cl::NDRange, CompareByRange> DefaultLocalWorkSize(
const cl::NDRange &global_work_size,
register size_t max_work_size,
size_t tune_type = 0,
const int &divitor = 2,
const bool &tune_reverse = false,
const size_t &user_defined_max_work_size = 0);
Expand Down

0 comments on commit b60d222

Please sign in to comment.