!11577 Use ThreadPool in ParallelFor

From: @wuxuejian
Reviewed-by: @kisnwang,@guoqi1024,@c_34
Signed-off-by: @c_34
pull/11577/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 48f2d82c0e

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "common/thread_pool.h"
namespace mindspore {
namespace kernel {
@ -81,21 +82,22 @@ void CPUKernelUtils::GetElementNumEveryDim(const std::vector<size_t> &shape, std
}
void CPUKernelUtils::ParallelFor(const CTask &task, size_t count) {
auto max_thread_num = std::thread::hardware_concurrency();
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
const float block_size = 128.0;
size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num;
std::vector<std::thread> threads;
threads.reserve(thread_num);
std::vector<common::Task> tasks;
size_t start = 0;
size_t once_compute_size = (count + thread_num - 1) / thread_num;
while (start < count) {
size_t end = (start + once_compute_size) > count ? count : (start + once_compute_size);
threads.emplace_back(std::thread(task, start, end));
auto block = [&, start, end]() {
task(start, end);
return common::SUCCESS;
};
tasks.emplace_back(block);
start += once_compute_size;
}
for (size_t i = 0; i < threads.size(); ++i) {
threads[i].join();
}
common::ThreadPool::GetInstance().SyncRun(tasks);
}
std::vector<size_t> CPUKernelUtils::FlatShapeByAxis(const std::vector<size_t> &shape, int axis) {

Loading…
Cancel
Save