!4262 Add thread pool

Merge pull request !4262 from huanghui/thread-pool
pull/4262/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit d0e29996ec

@ -16,8 +16,8 @@
#include "backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h" #include "backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h"
#include <string> #include <string>
#include <thread>
#include "runtime/device/cpu/cpu_device_address.h" #include "runtime/device/cpu/cpu_device_address.h"
#include "common/thread_pool.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
@ -116,18 +116,20 @@ void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &input
params.out_strides_ = &out_strides_; params.out_strides_ = &out_strides_;
const size_t thread_num = 24; const size_t thread_num = 24;
std::vector<std::thread> threads; std::vector<Task> tasks;
threads.reserve(thread_num);
size_t start = 0; size_t start = 0;
size_t once_compute_size = (num_units_ + thread_num - 1) / thread_num; size_t once_compute_size = (num_units_ + thread_num - 1) / thread_num;
while (start < num_units_) { while (start < num_units_) {
size_t end = (start + once_compute_size) > num_units_ ? num_units_ : (start + once_compute_size); size_t end = (start + once_compute_size) > num_units_ ? num_units_ : (start + once_compute_size);
threads.emplace_back(std::thread(Compute<T>, &params, start, end)); auto task = [&params, start, end]() -> int {
Compute<T>(&params, start, end);
return SUCCESS;
};
tasks.emplace_back(task);
start += once_compute_size; start += once_compute_size;
} }
for (size_t i = 0; i < threads.size(); ++i) { ThreadPool::GetInstance()->LaunchMultipleTask(tasks);
threads[i].join();
}
auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size); auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, inputs[0]->size);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;

@ -3,12 +3,14 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows")
"trans.cc" "trans.cc"
"utils.cc" "utils.cc"
"duplex_pipe_win.cc" "duplex_pipe_win.cc"
"thread_pool.cc"
) )
else() else()
file(GLOB_RECURSE _COMMON_ALL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} file(GLOB_RECURSE _COMMON_ALL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"trans.cc" "trans.cc"
"utils.cc" "utils.cc"
"duplex_pipe.cc" "duplex_pipe.cc"
"thread_pool.cc"
) )
endif() endif()

@ -0,0 +1,199 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/thread_pool.h"
#include <algorithm>
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"
namespace mindspore {
#ifdef ENABLE_D
const int kDeviceNum = 8;
#endif
bool Queue::Enqueue(Task *task) {
const int tail_index = tail_.load(std::memory_order_relaxed);
// queue full
auto next = (tail_index + 1) % 2;
if (next == head_.load(std::memory_order_acquire)) {
return false;
}
buffer_[tail_index] = task;
tail_.store(next, std::memory_order_release);
++task_size_;
return true;
}
bool Queue::Dequeue(Task **out) {
if (task_size_ == 0) {
return false;
}
// queue empty
const int head_index = head_.load(std::memory_order_relaxed);
if (head_index == tail_.load(std::memory_order_acquire)) {
return false;
}
*out = buffer_[head_index];
head_.store((head_index + 1) % 2, std::memory_order_release);
return true;
}
ThreadPool::ThreadPool() {
#ifdef ENABLE_D
auto cpu_core_num = std::thread::hardware_concurrency();
max_thread_num_ = cpu_core_num / kDeviceNum;
#endif
SetThreadPool(core_thread_num_);
}
bool ThreadPool::SetThreadPool(int config_thread_num) {
std::lock_guard<std::mutex> Lock(pool_mtx_);
if (config_thread_num > max_thread_num_) {
MS_LOG(EXCEPTION) << "Expected thread num is greater than the max thread num, expected thread num="
<< config_thread_num << ", allowed max thread num=" << max_thread_num_;
}
if (config_thread_num > cur_thread_nums_) {
AddNewThread(config_thread_num - cur_thread_nums_);
}
MS_LOG(DEBUG) << "cur_thread_nums_=" << cur_thread_nums_ << ", cur_thread_run_nums_=" << cur_thread_run_nums_;
return true;
}
void ThreadPool::AddNewThread(int add_num) {
for (int i = cur_thread_nums_, j = 0; j < add_num; ++i, ++j) {
auto active = new std::atomic_bool{true};
auto queue = std::make_shared<Queue>();
std::thread thread([this, i, active, queue]() {
Task *task = nullptr;
while (!exit_run_) {
while (*active) {
if (queue->Dequeue(&task)) {
auto ret = (*task)();
if (ret != SUCCESS) {
error_info_.emplace_back(std::make_pair(i, std::make_pair(false, ret)));
}
queue->task_size_--;
}
std::this_thread::yield();
}
std::unique_lock<std::mutex> queue_lock(thread_mtx_);
queue_ready_.wait(queue_lock, [active, this] { return exit_run_ || *active; });
}
});
thread_list_.emplace_back(std::move(thread));
activate_list_.emplace_back(active);
queue_list_.emplace_back(queue);
}
cur_thread_nums_ += add_num;
cur_thread_run_nums_ += add_num;
MS_LOG(INFO) << "add " << add_num << " thread";
}
void ThreadPool::AddRunThread(int num) {
MS_LOG(DEBUG) << "num=" << num << ", cur_thread_run_nums_=" << cur_thread_run_nums_;
int active_nums = num - cur_thread_run_nums_;
if (active_nums <= 0 || static_cast<int>(activate_list_.size()) < active_nums) {
return;
}
for (int i = cur_thread_run_nums_ - 1, j = 0; j < active_nums; ++i, ++j) {
*activate_list_[i] = true;
}
std::lock_guard<std::mutex> queueLock(thread_mtx_);
queue_ready_.notify_all();
cur_thread_run_nums_ = num;
}
void ThreadPool::SubRunThread(int num) {
MS_LOG(DEBUG) << "sub num=" << num << ", cur_thread_run_nums_=" << cur_thread_run_nums_;
int deactive_nums = cur_thread_run_nums_ - num;
if (deactive_nums <= 0) {
return;
}
for (int i = num, j = 0; j < deactive_nums; ++i, ++j) {
*activate_list_[i] = false;
}
cur_thread_run_nums_ = num;
}
bool ThreadPool::LaunchMultipleTask(const std::vector<Task> &tasks) {
int thread_num = tasks.size();
if (thread_num > max_thread_num_) {
thread_num = max_thread_num_;
}
if (!SetThreadPool(thread_num)) {
return false;
}
error_info_.clear();
bool succ_flag;
for (int task_id = 0, queue_index = 0; task_id < SizeToInt(tasks.size()); ++task_id) {
do {
succ_flag = true;
if (!queue_list_[queue_index]->Enqueue(const_cast<Task *>(&tasks[task_id]))) {
std::this_thread::yield();
succ_flag = false;
}
} while (!succ_flag);
queue_index++;
if (queue_index >= cur_thread_run_nums_) {
queue_index = queue_index - cur_thread_run_nums_;
}
}
succ_flag = false;
while (!succ_flag) {
std::this_thread::yield();
succ_flag = true;
for (int i = 0; i < cur_thread_run_nums_; ++i) {
if (queue_list_[i]->task_size_ != 0) {
succ_flag = false;
break;
}
}
}
MS_LOG(INFO) << "Finish " << tasks.size() << " task successful";
return CheckResult();
}
bool ThreadPool::CheckResult() {
bool succ_flag = true;
for (auto result : error_info_) {
if (result.second.first) {
MS_LOG(ERROR) << "task " << result.first << " failed, error code is " << result.second.second;
succ_flag = false;
}
}
return succ_flag;
}
ThreadPool *ThreadPool::GetInstance() {
static ThreadPool instance;
return &instance;
}
ThreadPool::~ThreadPool() {
cur_thread_run_nums_ = static_cast<int>(thread_list_.size());
exit_run_ = true;
SubRunThread(0);
queue_ready_.notify_all();
for (auto &it : thread_list_) {
if (it.joinable()) {
it.join();
}
}
for (const auto &it : activate_list_) {
delete it;
}
}
} // namespace mindspore

@ -0,0 +1,86 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_COMMON_THREAD_POOL_H_
#define MINDSPORE_CCSRC_COMMON_THREAD_POOL_H_
#include <mutex>
#include <condition_variable>
#include <thread>
#include <vector>
#include <queue>
#include <string>
#include <atomic>
#include <memory>
#include <utility>
#include <functional>
#include <iostream>
#include "utils/log_adapter.h"
namespace mindspore {
const int kCoreThreadNum = 3;
const int kDefaultMaxThreadNum = 8;
enum Status { FAIL = -1, SUCCESS = 0 };
using Task = std::function<int()>;
class Queue {
public:
Queue() = default;
~Queue() = default;
bool Enqueue(Task *task);
bool Dequeue(Task **out);
std::atomic_int task_size_ = {0};
private:
std::atomic_int head_ = {0};
std::atomic_int tail_ = {0};
Task *buffer_[2]{};
};
class ThreadPool {
public:
~ThreadPool();
ThreadPool(const ThreadPool &) = delete;
ThreadPool &operator=(const ThreadPool &) = delete;
static ThreadPool *GetInstance();
// Use the tasks' size of threads to execute these tasks, one thread execute one task.
bool LaunchMultipleTask(const std::vector<Task> &tasks);
private:
ThreadPool();
bool SetThreadPool(int config_thread_num);
void AddNewThread(int add_num);
void AddRunThread(int num);
void SubRunThread(int num);
bool CheckResult();
int cur_thread_nums_{0};
int cur_thread_run_nums_{0};
int core_thread_num_{kCoreThreadNum};
int max_thread_num_{kDefaultMaxThreadNum};
std::mutex pool_mtx_;
std::mutex thread_mtx_;
std::condition_variable queue_ready_;
std::atomic_bool exit_run_ = {false};
std::vector<std::atomic_bool *> activate_list_{};
std::vector<std::thread> thread_list_{};
std::vector<std::shared_ptr<Queue>> queue_list_{};
std::vector<std::pair<int, std::pair<bool, int>>> error_info_{};
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_COMMON_THREAD_POOL_H_

@ -23,7 +23,7 @@ from mindspore import Parameter
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.ops import operations as P from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target='CPU', save_graphs=True) context.set_context(mode=context.GRAPH_MODE, device_target='CPU', save_graphs=False)
class ScatterNdUpdate1(nn.Cell): class ScatterNdUpdate1(nn.Cell):

Loading…
Cancel
Save