|
|
|
@ -21,7 +21,8 @@ limitations under the License. */
|
|
|
|
|
#include <queue>
|
|
|
|
|
#include <thread>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "glog/logging.h"
|
|
|
|
|
#include "paddle/platform/enforce.h"
|
|
|
|
|
#include "paddle/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -31,7 +32,7 @@ namespace framework {
|
|
|
|
|
// number of threads.
|
|
|
|
|
class ThreadPool {
|
|
|
|
|
public:
|
|
|
|
|
typedef std::packaged_task<void()> Task;
|
|
|
|
|
using Task = std::packaged_task<std::unique_ptr<platform::EnforceNotMet>()>;
|
|
|
|
|
|
|
|
|
|
// Returns the singleton of ThreadPool.
|
|
|
|
|
static ThreadPool* GetInstance();
|
|
|
|
@ -52,9 +53,28 @@ class ThreadPool {
|
|
|
|
|
// std::future::wait().
|
|
|
|
|
template <typename Callback>
|
|
|
|
|
std::future<void> Run(Callback fn) {
|
|
|
|
|
auto f = this->RunAndGetException(fn);
|
|
|
|
|
return std::async(std::launch::deferred, ExceptionHandler(std::move(f)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Callback>
|
|
|
|
|
std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException(
|
|
|
|
|
Callback fn) {
|
|
|
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
|
|
|
Task task(std::bind(fn));
|
|
|
|
|
std::future<void> f = task.get_future();
|
|
|
|
|
Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> {
|
|
|
|
|
try {
|
|
|
|
|
fn();
|
|
|
|
|
return nullptr;
|
|
|
|
|
} catch (platform::EnforceNotMet ex) {
|
|
|
|
|
return std::unique_ptr<platform::EnforceNotMet>(
|
|
|
|
|
new platform::EnforceNotMet(ex));
|
|
|
|
|
} catch (...) {
|
|
|
|
|
LOG(FATAL)
|
|
|
|
|
<< "Unexpected exception is catched in thread pool. All "
|
|
|
|
|
"throwable exception in Fluid should be an EnforceNotMet.";
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
|
|
|
|
|
tasks_.push(std::move(task));
|
|
|
|
|
lock.unlock();
|
|
|
|
|
scheduled_.notify_one();
|
|
|
|
@ -65,6 +85,22 @@ class ThreadPool {
|
|
|
|
|
void Wait();
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
struct ExceptionHandler {
|
|
|
|
|
mutable std::future<std::unique_ptr<platform::EnforceNotMet>> future_;
|
|
|
|
|
explicit ExceptionHandler(
|
|
|
|
|
std::future<std::unique_ptr<platform::EnforceNotMet>>&& f)
|
|
|
|
|
: future_(std::move(f)) {}
|
|
|
|
|
void operator()() const {
|
|
|
|
|
auto ex = this->future_.get();
|
|
|
|
|
if (ex != nullptr) {
|
|
|
|
|
LOG(FATAL) << "The exception is thrown inside the thread pool. You "
|
|
|
|
|
"should use RunAndGetException to handle the exception.\n"
|
|
|
|
|
"The default exception handler is LOG(FATAL)."
|
|
|
|
|
<< ex->what();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DISABLE_COPY_AND_ASSIGN(ThreadPool);
|
|
|
|
|
|
|
|
|
|
explicit ThreadPool(int num_threads);
|
|
|
|
|