diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc
index ba14c733176e0..5f5a811c2d30e 100644
--- a/src/runtime/thread_pool.cc
+++ b/src/runtime/thread_pool.cc
@@ -363,21 +363,38 @@ TVM_REGISTER_GLOBAL("runtime.config_threadpool").set_body([](TVMArgs args, TVMRe
 }  // namespace tvm
 
+// Runs `flambda`: once inline (task id 0, env.num_task == 1) when
+// MaxConcurrency() reports a single worker, otherwise via ThreadPool::Launch,
+// or from an OpenMP parallel region when TVM_THREADPOOL_USE_OPENMP is set.
+// Returns the lambda's / Launch()'s status; the OpenMP branch always returns 0.
 int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) {
+  int num_workers = tvm::runtime::threading::MaxConcurrency();
+  if (num_workers == 1) {
+    // Template argument spelled out so this does not rely on C++17 class
+    // template argument deduction.
+    std::atomic<int> sync_counter{0};
+    TVMParallelGroupEnv env;
+    env.num_task = 1;
+    env.sync_handle = &sync_counter;
+    // Forward the lambda's status instead of always reporting success, the
+    // same way the thread-pool branch returns what Launch() reports.
+    return (*flambda)(0, &env, cdata);
+  } else {
 #if !TVM_THREADPOOL_USE_OPENMP
-  int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(flambda, cdata, num_task, 1);
-  return res;
+    int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(flambda, cdata, num_task, 1);
+    return res;
 #else
-  int num_workers = tvm::runtime::threading::MaxConcurrency();
-  if (num_task == 0) num_task = num_workers;
-  omp_set_num_threads(num_task);
+    // A num_task of 0 falls back to num_workers.
+    if (num_task == 0) num_task = num_workers;
+    omp_set_num_threads(num_task);
 #pragma omp parallel num_threads(num_task)
-  {
-    TVMParallelGroupEnv env;
-    env.num_task = num_task;
-    (*flambda)(omp_get_thread_num(), &env, cdata);
-  }
-  return 0;
+    {
+      TVMParallelGroupEnv env;
+      env.num_task = num_task;
+      (*flambda)(omp_get_thread_num(), &env, cdata);
+    }
+    return 0;
 #endif
+  }
 }
 
 int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) {