ThreadPool::Run interface returns std::future (#7099)

* Run interface returns a future

* delete unused comments
* del_some_in_makelist
Yancey committed 7 years ago (via GitHub)
commit 5022ee6359, parent 1831176764
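In caller terms, the change below amounts to this (a minimal sketch; the lambda body is illustrative):

    framework::ThreadPool* pool = framework::ThreadPool::GetInstance();

    // Before: fire-and-forget; the only way to synchronize was pool->Wait(),
    // which blocks until *all* queued tasks are done.
    // pool->Run([] { /* work */ });

    // After: Run returns a std::future<void> tied to this one task.
    std::future<void> f = pool->Run([] { /* work */ });
    f.wait();  // blocks until just this task has finished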

@@ -16,6 +16,7 @@ limitations under the License. */
 #include <condition_variable>
 #include <functional>
+#include <future>
 #include <mutex>
 #include <queue>
 #include <thread>
@@ -25,10 +26,11 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
-typedef std::function<void()> Task;
 class ThreadPool {
  public:
+  typedef std::packaged_task<void()> Task;
+  typedef std::function<void()> Fun;
   /**
    * @brief Get a instance of threadpool, the thread number will
    *        be specified as the number of hardware thread contexts
@@ -61,13 +63,18 @@ class ThreadPool {
   /**
    * @brief Push a function to the queue, and will be scheduled and
    *        executed if a thread is available.
-   * @param[in] Task will be pushed to the task queue.
+   * @param[in] Task, will be pushed to the task queue.
+   * @return std::future<void>, we could wait for the task finished by
+   *         f.wait().
    */
-  void Run(const Task& fn) {
+  std::future<void> Run(const Fun& fn) {
     std::unique_lock<std::mutex> lock(mutex_);
-    tasks_.push(fn);
+    Task task(std::bind(fn));
+    std::future<void> f = task.get_future();
+    tasks_.push(std::move(task));
     lock.unlock();
     scheduled_.notify_one();
+    return f;
   }

   /**
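Run now wraps the callable in the new Task typedef (std::packaged_task<void()>); the returned future shares state with that task, so it becomes ready once a worker thread invokes the task. The same mechanism in isolation, outside the pool (a sketch, not PaddlePaddle code):

    #include <functional>
    #include <future>
    #include <iostream>

    int main() {
      std::function<void()> fn = [] { std::cout << "task ran\n"; };
      std::packaged_task<void()> task(std::bind(fn));  // same wrapping Run() does
      std::future<void> f = task.get_future();  // future bound to the task's shared state
      task();   // in the pool, a worker thread performs this invocation
      f.wait(); // returns only after the task body has executed
      return 0;
    }

Note that std::bind(fn) is arguably redundant: fn is already a nullary callable, so constructing the task from fn directly would behave the same.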
@@ -110,7 +117,7 @@ class ThreadPool {
         break;
       }
       // pop a task from the task queue
-      auto task = tasks_.front();
+      auto task = std::move(tasks_.front());
       tasks_.pop();
       --available_;
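The std::move in the worker loop is required, not stylistic: std::packaged_task is move-only (its copy constructor is deleted), whereas the old std::function tasks were copyable. A minimal sketch of the constraint:

    #include <future>
    #include <queue>
    #include <utility>

    int main() {
      std::queue<std::packaged_task<void()>> tasks;
      std::packaged_task<void()> t([] {});
      tasks.push(std::move(t));              // tasks.push(t) would not compile: deleted copy ctor
      auto task = std::move(tasks.front());  // likewise, front() must be moved from, not copied
      tasks.pop();
      task();  // invoking the moved-to task runs the wrapped callable
      return 0;
    }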

@@ -20,16 +20,21 @@ limitations under the License. */
 namespace framework = paddle::framework;

 void do_sum(framework::ThreadPool* pool, std::atomic<int>& sum, int cnt) {
+  std::vector<std::future<void>> fs;
   for (int i = 0; i < cnt; ++i) {
-    pool->Run([&sum]() { sum.fetch_add(1); });
+    auto f = pool->Run([&sum]() { sum.fetch_add(1); });
+    fs.push_back(std::move(f));
+  }
+  for (auto& f : fs) {
+    f.wait();
   }
 }

 TEST(ThreadPool, ConcurrentInit) {
   framework::ThreadPool* pool;
-  int concurrent_cnt = 50;
+  int n = 50;
   std::vector<std::thread> threads;
-  for (int i = 0; i < concurrent_cnt; ++i) {
+  for (int i = 0; i < n; ++i) {
     std::thread t([&pool]() { pool = framework::ThreadPool::GetInstance(); });
     threads.push_back(std::move(t));
   }
@@ -38,13 +43,13 @@ TEST(ThreadPool, ConcurrentInit) {
   }
 }

-TEST(ThreadPool, ConcurrentStart) {
+TEST(ThreadPool, ConcurrentRun) {
   framework::ThreadPool* pool = framework::ThreadPool::GetInstance();
   std::atomic<int> sum(0);
   std::vector<std::thread> threads;
-  int concurrent_cnt = 50;
+  int n = 50;
   // sum = (n * (n + 1)) / 2
-  for (int i = 1; i <= concurrent_cnt; ++i) {
+  for (int i = 1; i <= n; ++i) {
     std::thread t(do_sum, pool, std::ref(sum), i);
     threads.push_back(std::move(t));
   }
@@ -52,5 +57,5 @@ TEST(ThreadPool, ConcurrentStart) {
     t.join();
   }
   pool->Wait();
-  EXPECT_EQ(sum, ((concurrent_cnt + 1) * concurrent_cnt) / 2);
+  EXPECT_EQ(sum, ((n + 1) * n) / 2);
 }
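The reworked do_sum shows the intended usage pattern: collect the future from each Run call and wait on exactly those, rather than calling pool->Wait(), which drains every task in the singleton pool, including tasks submitted by other threads. A condensed sketch of that distinction (wait_for_my_tasks and the work lambda are illustrative, not part of the commit):

    #include <future>
    #include <vector>
    // assumes the header this commit edits is included
    namespace framework = paddle::framework;

    void wait_for_my_tasks(framework::ThreadPool* pool) {
      std::vector<std::future<void>> fs;
      for (int i = 0; i < 4; ++i) {
        fs.push_back(pool->Run([] { /* work */ }));
      }
      for (auto& f : fs) {
        f.wait();  // waits only for the four tasks submitted above
      }
      // pool->Wait();  // by contrast, drains every task in the shared pool
    }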
