@ -32,23 +32,25 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
" 'Place' is not supported, Please re-compile with WITH_GPU "
" option " ) ;
}
return it - > second . get ( ) ;
return it - > second . get ( ) .get ( ) ;
}
const std : : vector < const DeviceContext * >
DeviceContextPool : : GetAllDeviceContexts ( ) const {
std : : vector < const DeviceContext * > all_device_ctx ;
all_device_ctx . reserve ( device_contexts_ . size ( ) ) ;
for ( auto & dev_ctx : device_contexts_ ) {
all_device_ctx . emplace_back ( dev_ctx . second . get ( ) ) ;
}
return all_device_ctx ;
template < typename DevCtx , typename PlaceType >
inline void EmplaceDeviceContext (
std : : map < Place , std : : shared_future < std : : unique_ptr < DeviceContext > > > *
map_ptr ,
platform : : Place p ) {
using PtrType = std : : unique_ptr < DeviceContext > ;
map_ptr - > emplace ( p , std : : async ( std : : launch : : deferred , [ = ] {
// lazy evaluation. i.e., only create device context at
// first `Get`
return PtrType ( new DevCtx ( boost : : get < PlaceType > ( p ) ) ) ;
} ) ) ;
}
DeviceContextPool : : DeviceContextPool (
const std : : vector < platform : : Place > & places ) {
PADDLE_ENFORCE_GT ( places . size ( ) , 0 ) ;
using PtrType = std : : unique_ptr < DeviceContext > ;
std : : set < Place > set ;
for ( auto & p : places ) {
set . insert ( p ) ;
@ -57,16 +59,13 @@ DeviceContextPool::DeviceContextPool(
for ( auto & p : set ) {
if ( platform : : is_cpu_place ( p ) ) {
# ifdef PADDLE_WITH_MKLDNN
device_contexts_ . emplace (
p , PtrType ( new MKLDNNDeviceContext ( boost : : get < CPUPlace > ( p ) ) ) ) ;
EmplaceDeviceContext < MKLDNNDeviceContext , CPUPlace > ( & device_contexts_ , p ) ;
# else
device_contexts_ . emplace (
p , PtrType ( new CPUDeviceContext ( boost : : get < CPUPlace > ( p ) ) ) ) ;
EmplaceDeviceContext < CPUDeviceContext , CPUPlace > ( & device_contexts_ , p ) ;
# endif
} else if ( platform : : is_gpu_place ( p ) ) {
# ifdef PADDLE_WITH_CUDA
device_contexts_ . emplace (
p , PtrType ( new CUDADeviceContext ( boost : : get < CUDAPlace > ( p ) ) ) ) ;
EmplaceDeviceContext < CUDADeviceContext , CUDAPlace > ( & device_contexts_ , p ) ;
# else
PADDLE_THROW (
" 'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
@ -74,9 +73,8 @@ DeviceContextPool::DeviceContextPool(
# endif
} else if ( platform : : is_cuda_pinned_place ( p ) ) {
# ifdef PADDLE_WITH_CUDA
device_contexts_ . emplace (
p ,
PtrType ( new CUDAPinnedDeviceContext ( boost : : get < CUDAPinnedPlace > ( p ) ) ) ) ;
EmplaceDeviceContext < CUDAPinnedDeviceContext , CUDAPinnedPlace > (
& device_contexts_ , p ) ;
# else
PADDLE_THROW (
" 'CUDAPlace' is not supported, Please re-compile with WITH_GPU "