@ -21,7 +21,8 @@ limitations under the License. */
# include <queue>
# include <thread>
# include <vector>
# include "glog/logging.h"
# include "paddle/platform/enforce.h"
# include "paddle/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
@ -31,7 +32,7 @@ namespace framework {
// number of threads.
class ThreadPool {
public :
typedef std : : packaged_task < void ( ) > Task ;
using Task = std : : packaged_task < std : : unique_ptr < platform : : EnforceNotMet > ( ) > ;
// Returns the singleton of ThreadPool.
static ThreadPool * GetInstance ( ) ;
@ -52,9 +53,28 @@ class ThreadPool {
// std::future::wait().
template < typename Callback >
std : : future < void > Run ( Callback fn ) {
auto f = this - > RunAndGetException ( fn ) ;
return std : : async ( std : : launch : : deferred , ExceptionHandler ( std : : move ( f ) ) ) ;
}
template < typename Callback >
std : : future < std : : unique_ptr < platform : : EnforceNotMet > > RunAndGetException (
Callback fn ) {
std : : unique_lock < std : : mutex > lock ( mutex_ ) ;
Task task ( std : : bind ( fn ) ) ;
std : : future < void > f = task . get_future ( ) ;
Task task ( [ fn ] ( ) - > std : : unique_ptr < platform : : EnforceNotMet > {
try {
fn ( ) ;
return nullptr ;
} catch ( platform : : EnforceNotMet ex ) {
return std : : unique_ptr < platform : : EnforceNotMet > (
new platform : : EnforceNotMet ( ex ) ) ;
} catch ( . . . ) {
LOG ( FATAL )
< < " Unexpected exception is catched in thread pool. All "
" throwable exception in Fluid should be an EnforceNotMet. " ;
}
} ) ;
std : : future < std : : unique_ptr < platform : : EnforceNotMet > > f = task . get_future ( ) ;
tasks_ . push ( std : : move ( task ) ) ;
lock . unlock ( ) ;
scheduled_ . notify_one ( ) ;
@ -65,6 +85,22 @@ class ThreadPool {
void Wait ( ) ;
private :
struct ExceptionHandler {
mutable std : : future < std : : unique_ptr < platform : : EnforceNotMet > > future_ ;
explicit ExceptionHandler (
std : : future < std : : unique_ptr < platform : : EnforceNotMet > > & & f )
: future_ ( std : : move ( f ) ) { }
void operator ( ) ( ) const {
auto ex = this - > future_ . get ( ) ;
if ( ex ! = nullptr ) {
LOG ( FATAL ) < < " The exception is thrown inside the thread pool. You "
" should use RunAndGetException to handle the exception. \n "
" The default exception handler is LOG(FATAL). "
< < ex - > what ( ) ;
}
}
} ;
DISABLE_COPY_AND_ASSIGN ( ThreadPool ) ;
explicit ThreadPool ( int num_threads ) ;