From 171a7784e1ab1f02fb1cb7c63fba037724618192 Mon Sep 17 00:00:00 2001 From: wuxuejian Date: Sat, 23 Jan 2021 17:50:55 +0800 Subject: [PATCH] Use ThreadPool in ParallelFor --- .../backend/kernel_compiler/cpu/cpu_kernel.cc | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc index baf1c9e30a..38acb68929 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc @@ -14,6 +14,7 @@ * limitations under the License. */ #include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "common/thread_pool.h" namespace mindspore { namespace kernel { @@ -81,21 +82,22 @@ void CPUKernelUtils::GetElementNumEveryDim(const std::vector &shape, std } void CPUKernelUtils::ParallelFor(const CTask &task, size_t count) { - auto max_thread_num = std::thread::hardware_concurrency(); + auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); const float block_size = 128.0; size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num; - std::vector threads; - threads.reserve(thread_num); + std::vector tasks; size_t start = 0; size_t once_compute_size = (count + thread_num - 1) / thread_num; while (start < count) { size_t end = (start + once_compute_size) > count ? count : (start + once_compute_size); - threads.emplace_back(std::thread(task, start, end)); + auto block = [&, start, end]() { + task(start, end); + return common::SUCCESS; + }; + tasks.emplace_back(block); start += once_compute_size; } - for (size_t i = 0; i < threads.size(); ++i) { - threads[i].join(); - } + common::ThreadPool::GetInstance().SyncRun(tasks); } std::vector CPUKernelUtils::FlatShapeByAxis(const std::vector &shape, int axis) {