@@ -32,23 +32,25 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
        "'Place' is not supported, Please re-compile with WITH_GPU "
        "option");
   }
-  return it->second.get();
+  return it->second.get().get();
 }
 
-const std::vector<const DeviceContext*>
-DeviceContextPool::GetAllDeviceContexts() const {
-  std::vector<const DeviceContext*> all_device_ctx;
-  all_device_ctx.reserve(device_contexts_.size());
-  for (auto& dev_ctx : device_contexts_) {
-    all_device_ctx.emplace_back(dev_ctx.second.get());
-  }
-  return all_device_ctx;
+template <typename DevCtx, typename PlaceType>
+inline void EmplaceDeviceContext(
+    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
+        map_ptr,
+    platform::Place p) {
+  using PtrType = std::unique_ptr<DeviceContext>;
+  map_ptr->emplace(p, std::async(std::launch::deferred, [=] {
+                     // lazy evaluation. i.e., only create device context at
+                     // first `Get`
+                     return PtrType(new DevCtx(boost::get<PlaceType>(p)));
+                   }));
 }
 
 DeviceContextPool::DeviceContextPool(
     const std::vector<platform::Place>& places) {
   PADDLE_ENFORCE_GT(places.size(), 0);
-  using PtrType = std::unique_ptr<DeviceContext>;
   std::set<Place> set;
   for (auto& p : places) {
     set.insert(p);
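
The hunk above switches the pool from eager to lazy construction: each place now maps to a std::shared_future whose deferred lambda only builds the DeviceContext the first time Get() is called, which is why the return statement gains an extra .get() to unwrap the future before unwrapping the unique_ptr. A minimal, standalone sketch of the same std::async(std::launch::deferred) idiom is shown below; FakeContext and the integer-keyed pool are hypothetical stand-ins for illustration, not Paddle types.

#include <future>
#include <iostream>
#include <map>
#include <memory>

// Hypothetical stand-in for an expensive-to-construct device context.
struct FakeContext {
  explicit FakeContext(int id) : id(id) {
    std::cout << "constructing context " << id << "\n";  // runs lazily
  }
  int id;
};

int main() {
  std::map<int, std::shared_future<std::unique_ptr<FakeContext>>> pool;

  // Registration is cheap: std::launch::deferred stores the lambda
  // without invoking it, so no context is constructed here.
  for (int id : {0, 1, 2}) {
    pool.emplace(id, std::async(std::launch::deferred, [id] {
                   return std::unique_ptr<FakeContext>(new FakeContext(id));
                 }));
  }
  std::cout << "pool populated, nothing constructed yet\n";

  // The first .get() on the shared_future invokes the deferred lambda; the
  // second .get() unwraps the unique_ptr, mirroring `it->second.get().get()`.
  FakeContext* ctx = pool.at(1).get().get();
  std::cout << "got context " << ctx->id << "\n";
  return 0;
}

Subsequent calls to pool.at(1).get() return the already-constructed object without re-running the lambda, which is what makes registering a context per place cheap while still creating each one on demand.
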
@@ -57,16 +59,13 @@ DeviceContextPool::DeviceContextPool(
   for (auto& p : set) {
     if (platform::is_cpu_place(p)) {
 #ifdef PADDLE_WITH_MKLDNN
-      device_contexts_.emplace(
-          p, PtrType(new MKLDNNDeviceContext(boost::get<CPUPlace>(p))));
+      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
 #else
-      device_contexts_.emplace(
-          p, PtrType(new CPUDeviceContext(boost::get<CPUPlace>(p))));
+      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
 #endif
     } else if (platform::is_gpu_place(p)) {
 #ifdef PADDLE_WITH_CUDA
-      device_contexts_.emplace(
-          p, PtrType(new CUDADeviceContext(boost::get<CUDAPlace>(p))));
+      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
 #else
       PADDLE_THROW(
           "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
@@ -74,9 +73,8 @@ DeviceContextPool::DeviceContextPool(
 #endif
     } else if (platform::is_cuda_pinned_place(p)) {
 #ifdef PADDLE_WITH_CUDA
-      device_contexts_.emplace(
-          p,
-          PtrType(new CUDAPinnedDeviceContext(boost::get<CUDAPinnedPlace>(p))));
+      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
+          &device_contexts_, p);
 #else
       PADDLE_THROW(
           "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "