From 0c7303de954ae7745490562693998f3bfbab9f02 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 13 Feb 2021 09:54:37 +0900 Subject: [PATCH 1/2] Fast path for single thread run to allow app level threading --- src/runtime/thread_pool.cc | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index ba14c733176e..b236b7f48b16 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -363,21 +363,28 @@ TVM_REGISTER_GLOBAL("runtime.config_threadpool").set_body([](TVMArgs args, TVMRe } // namespace tvm int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) { + int num_workers = tvm::runtime::threading::MaxConcurrency(); + if (num_workers == 1) { + TVMParallelGroupEnv env; + env.num_task = 1; + (*flambda)(1, &env, cdata); + return 0; + } else { #if !TVM_THREADPOOL_USE_OPENMP - int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(flambda, cdata, num_task, 1); - return res; + int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(flambda, cdata, num_task, 1); + return res; #else - int num_workers = tvm::runtime::threading::MaxConcurrency(); - if (num_task == 0) num_task = num_workers; - omp_set_num_threads(num_task); + if (num_task == 0) num_task = num_workers; + omp_set_num_threads(num_task); #pragma omp parallel num_threads(num_task) - { - TVMParallelGroupEnv env; - env.num_task = num_task; - (*flambda)(omp_get_thread_num(), &env, cdata); - } - return 0; + { + TVMParallelGroupEnv env; + env.num_task = num_task; + (*flambda)(omp_get_thread_num(), &env, cdata); + } + return 0; #endif + } } int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { From 18fdf41150511f6562ba8904b5519c5950bc0e93 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 13 Feb 2021 12:22:19 +0900 Subject: [PATCH 2/2] add sync counter to avoid error in one of tests --- src/runtime/thread_pool.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index b236b7f48b16..5f5a811c2d30 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -365,9 +365,11 @@ TVM_REGISTER_GLOBAL("runtime.config_threadpool").set_body([](TVMArgs args, TVMRe int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) { int num_workers = tvm::runtime::threading::MaxConcurrency(); if (num_workers == 1) { + std::atomic sync_counter{0}; TVMParallelGroupEnv env; env.num_task = 1; - (*flambda)(1, &env, cdata); + env.sync_handle = &sync_counter; + (*flambda)(0, &env, cdata); return 0; } else { #if !TVM_THREADPOOL_USE_OPENMP