remove useless thread pool code

pull/14571/head
kswang 4 years ago
parent 444ff97206
commit 0909ba2884

@ -24,38 +24,12 @@
namespace mindspore {
namespace common {
#ifdef ENABLE_D
const int kDeviceNum = 8;
const size_t kDeviceNum = 8;
#endif
const int kMaxThreadNum = 23;
bool Queue::Enqueue(Task *task) {
const int tail_index = tail_.load(std::memory_order_relaxed);
// queue full
auto next = (tail_index + 1) % 2;
if (next == head_.load(std::memory_order_acquire)) {
return false;
}
buffer_[tail_index] = task;
tail_.store(next, std::memory_order_release);
++task_size_;
return true;
}
bool Queue::Dequeue(Task **out) {
if (task_size_ == 0) {
return false;
}
// queue empty
const int head_index = head_.load(std::memory_order_relaxed);
if (head_index == tail_.load(std::memory_order_acquire)) {
return false;
}
*out = buffer_[head_index];
head_.store((head_index + 1) % 2, std::memory_order_release);
return true;
}
const size_t kMaxThreadNum = 23;
ThreadPool::ThreadPool() {
int process_core_num = std::thread::hardware_concurrency() - 1;
size_t process_core_num = std::thread::hardware_concurrency() - 1;
if (process_core_num < 1) {
process_core_num = 1;
}
@ -72,80 +46,6 @@ ThreadPool::ThreadPool() {
}
}
bool ThreadPool::SetThreadPool(int config_thread_num) {
if (config_thread_num > max_thread_num_) {
MS_LOG(EXCEPTION) << "Expected thread num is greater than the max thread num, expected thread num="
<< config_thread_num << ", allowed max thread num=" << max_thread_num_;
}
if (config_thread_num > cur_thread_nums_) {
AddNewThread(config_thread_num - cur_thread_nums_);
}
MS_LOG(DEBUG) << "cur_thread_nums_=" << cur_thread_nums_ << ", cur_thread_run_nums_=" << cur_thread_run_nums_;
return true;
}
void ThreadPool::AddNewThread(int add_num) {
for (int i = cur_thread_nums_, j = 0; j < add_num; ++i, ++j) {
auto active = new std::atomic_bool{true};
auto queue = std::make_shared<Queue>();
std::thread thread([this, i, active, queue]() {
Task *task = nullptr;
while (!exit_run_) {
while (*active) {
if (queue->Dequeue(&task)) {
int ret;
try {
ret = (*task)();
} catch (std::exception &e) {
ret = FAIL;
MsException::Instance().SetException();
}
if (ret != SUCCESS) {
error_info_.emplace_back(std::make_pair(i, std::make_pair(false, ret)));
}
queue->task_size_--;
}
std::this_thread::yield();
}
std::unique_lock<std::mutex> queue_lock(thread_mtx_);
queue_ready_.wait(queue_lock, [active, this] { return exit_run_ || *active; });
}
});
thread_list_.emplace_back(std::move(thread));
activate_list_.emplace_back(active);
queue_list_.emplace_back(queue);
}
cur_thread_nums_ += add_num;
cur_thread_run_nums_ += add_num;
MS_LOG(INFO) << "add " << add_num << " thread";
}
void ThreadPool::AddRunThread(int num) {
MS_LOG(DEBUG) << "num=" << num << ", cur_thread_run_nums_=" << cur_thread_run_nums_;
int active_nums = num - cur_thread_run_nums_;
if (active_nums <= 0 || static_cast<int>(activate_list_.size()) < active_nums) {
return;
}
for (int i = cur_thread_run_nums_ - 1, j = 0; j < active_nums; ++i, ++j) {
*activate_list_[i] = true;
}
std::lock_guard<std::mutex> queueLock(thread_mtx_);
queue_ready_.notify_all();
cur_thread_run_nums_ = num;
}
void ThreadPool::SubRunThread(int num) {
MS_LOG(DEBUG) << "sub num=" << num << ", cur_thread_run_nums_=" << cur_thread_run_nums_;
int deactive_nums = cur_thread_run_nums_ - num;
if (deactive_nums <= 0) {
return;
}
for (int i = num, j = 0; j < deactive_nums; ++i, ++j) {
*activate_list_[i] = false;
}
cur_thread_run_nums_ = num;
}
void ThreadPool::SyncRunLoop() {
while (true) {
Task task;
@ -178,14 +78,14 @@ bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
}
std::unique_lock<std::mutex> lock(pool_mtx_);
exit_run_ = false;
int task_num = tasks.size();
int thread_num = sync_run_threads_.size();
size_t task_num = tasks.size();
size_t thread_num = sync_run_threads_.size();
if (thread_num < max_thread_num_ && thread_num < task_num) {
auto new_thread_num = max_thread_num_;
if (task_num < max_thread_num_) {
new_thread_num = task_num;
}
for (int i = thread_num; i < new_thread_num; ++i) {
for (size_t i = thread_num; i < new_thread_num; ++i) {
sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this));
}
}
@ -203,56 +103,6 @@ bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
return true;
}
bool ThreadPool::InnerSyncRun(const std::vector<Task> &tasks) {
std::lock_guard<std::mutex> sync_run_lock(pool_mtx_);
int thread_num = tasks.size();
if (thread_num > max_thread_num_) {
thread_num = max_thread_num_;
}
if (!SetThreadPool(thread_num)) {
return false;
}
error_info_.clear();
bool succ_flag;
for (int task_id = 0, queue_index = 0; task_id < SizeToInt(tasks.size()); ++task_id) {
do {
succ_flag = true;
if (!queue_list_[queue_index]->Enqueue(const_cast<Task *>(&tasks[task_id]))) {
std::this_thread::yield();
succ_flag = false;
}
} while (!succ_flag);
queue_index++;
if (queue_index >= cur_thread_run_nums_) {
queue_index = queue_index - cur_thread_run_nums_;
}
}
succ_flag = false;
while (!succ_flag) {
std::this_thread::yield();
succ_flag = true;
for (int i = 0; i < cur_thread_run_nums_; ++i) {
if (queue_list_[i]->task_size_ != 0) {
succ_flag = false;
break;
}
}
}
MS_LOG(INFO) << "Finish " << tasks.size() << " task successful";
return CheckResult();
}
bool ThreadPool::CheckResult() {
bool succ_flag = true;
for (auto result : error_info_) {
if (result.second.first) {
MS_LOG(ERROR) << "task " << result.first << " failed, error code is " << result.second.second;
succ_flag = false;
}
}
return succ_flag;
}
ThreadPool &ThreadPool::GetInstance() {
static ThreadPool instance;
return instance;
@ -264,9 +114,6 @@ void ThreadPool::ClearThreadPool() {
return;
}
exit_run_ = true;
cur_thread_run_nums_ = static_cast<int>(thread_list_.size());
SubRunThread(0);
queue_ready_.notify_all();
task_cond_var_.notify_all();
for (auto &it : sync_run_threads_) {
if (it.joinable()) {
@ -274,16 +121,6 @@ void ThreadPool::ClearThreadPool() {
}
}
sync_run_threads_.clear();
for (auto &it : thread_list_) {
if (it.joinable()) {
it.join();
}
}
thread_list_.clear();
for (const auto &it : activate_list_) {
delete it;
}
activate_list_.clear();
}
ThreadPool::~ThreadPool() { ClearThreadPool(); }

@ -32,25 +32,9 @@
namespace mindspore {
namespace common {
const int kCoreThreadNum = 3;
const int kDefaultMaxThreadNum = 8;
enum Status { FAIL = -1, SUCCESS = 0 };
using Task = std::function<int()>;
class Queue {
public:
Queue() = default;
~Queue() = default;
bool Enqueue(Task *task);
bool Dequeue(Task **out);
std::atomic_int task_size_ = {0};
private:
std::atomic_int head_ = {0};
std::atomic_int tail_ = {0};
Task *buffer_[2]{};
};
class ThreadPool {
public:
~ThreadPool();
@ -63,30 +47,15 @@ class ThreadPool {
private:
ThreadPool();
bool SetThreadPool(int config_thread_num);
void AddNewThread(int add_num);
void AddRunThread(int num);
void SubRunThread(int num);
bool CheckResult();
bool InnerSyncRun(const std::vector<Task> &tasks);
void SyncRunLoop();
int cur_thread_nums_{0};
int cur_thread_run_nums_{0};
int core_thread_num_{kCoreThreadNum};
int max_thread_num_{kDefaultMaxThreadNum};
size_t max_thread_num_{1};
std::mutex pool_mtx_;
std::mutex thread_mtx_;
std::condition_variable queue_ready_;
std::atomic_bool exit_run_ = {false};
std::vector<std::atomic_bool *> activate_list_{};
std::vector<std::thread> thread_list_{};
std::vector<std::shared_ptr<Queue>> queue_list_{};
std::vector<std::pair<int, std::pair<bool, int>>> error_info_{};
std::queue<Task> task_queue_;
std::mutex task_mutex_;
std::condition_variable task_cond_var_;
int task_finished_count_{0};
size_t task_finished_count_{0};
std::condition_variable finished_cond_var_;
std::vector<std::thread> sync_run_threads_{};
};

Loading…
Cancel
Save