@ -48,6 +48,13 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
platform : : MKLDNNHandler : : AppendKey ( & key , std : : to_string ( dt ) ) ;
platform : : MKLDNNHandler : : AppendKey ( & key , std : : to_string ( fmt ) ) ;
platform : : MKLDNNHandler : : AppendKey ( & key , suffix ) ;
if ( platform : : get_cur_thread_id ( ) ! = - 1 ) {
auto tid = std : : this_thread : : get_id ( ) ;
std : : stringstream ss ;
ss < < tid ;
platform : : MKLDNNHandler : : AppendKey ( & key , " -t: " ) ;
platform : : MKLDNNHandler : : AppendKey ( & key , ss . str ( ) ) ;
}
return key ;
}
@ -128,6 +135,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std : : string key_pool_workspace_memory =
key + " @pool_workspace_memory " ;
std : : shared_ptr < mkldnn : : memory > src_memory , dst_memory ;
std : : shared_ptr < mkldnn : : pooling_forward : : primitive_desc > pool_pd ;
std : : shared_ptr < mkldnn : : memory > pool_src_memory_p , pool_dst_memory_p ;
auto pool_p =
std : : static_pointer_cast < pooling_forward > ( dev_ctx . GetBlob ( key_pool_p ) ) ;
if ( pool_p = = nullptr ) {
@ -158,9 +169,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// save pool_pd into global device context to be referred in backward path
if ( ! is_test ) dev_ctx . SetBlob ( key_pool_pd , pool_pd ) ;
auto src_memory = std : : make_shared < memory > ( pool_pd - > src_primitive_desc ( ) ,
to_void_cast < T > ( input_data ) ) ;
auto dst_memory =
src_memory = std : : make_shared < memory > ( pool_pd - > src_primitive_desc ( ) ,
to_void_cast < T > ( input_data ) ) ;
dst_memory =
std : : make_shared < memory > ( pool_pd - > dst_primitive_desc ( ) , output_data ) ;
dev_ctx . SetBlob ( key_pool_src_mem_p , src_memory ) ;
@ -186,11 +197,11 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
( memory : : format ) dst_memory - > get_primitive_desc ( ) . desc ( ) . data . format ;
} else {
// Primitives already exist
auto pool_src_memory_p =
pool_src_memory_p =
std : : static_pointer_cast < memory > ( dev_ctx . GetBlob ( key_pool_src_mem_p ) ) ;
PADDLE_ENFORCE ( pool_src_memory_p ! = nullptr ,
" Fail to find pooling src mem_p in device context " ) ;
auto pool_dst_memory_p =
pool_dst_memory_p =
std : : static_pointer_cast < memory > ( dev_ctx . GetBlob ( key_pool_dst_mem_p ) ) ;
PADDLE_ENFORCE ( pool_dst_memory_p ! = nullptr ,
" Fail to find pooling dst mem_p in device context " ) ;