@ -337,27 +337,26 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
// may be executed by diffrent thread, hence
// for that one we use key that does not contain TID
const std : : string key_activation_pd = key_common_ + " @activation_pd " ;
activation_pd_ =
std : : static_pointer_cast < mkldnn : : eltwise_forward : : primitive_desc > (
dev_ctx_ . GetBlob ( key_activation_pd ) ) ;
if ( activation_pd_ = = nullptr ) {
fwd_pd_ = std : : static_pointer_cast < mkldnn : : eltwise_forward : : primitive_desc > (
dev_ctx_ . GetBlob ( key_activation_pd ) ) ;
if ( fwd_pd_ = = nullptr ) {
static std : : mutex acquire_barrier ;
std : : lock_guard < std : : mutex > block_threads_until_finish_this_job (
acquire_barrier ) ;
activation _pd_ =
fwd _pd_ =
std : : static_pointer_cast < mkldnn : : eltwise_forward : : primitive_desc > (
dev_ctx_ . GetBlob ( key_activation_pd ) ) ;
if ( activation _pd_ = = nullptr ) {
if ( fwd _pd_ = = nullptr ) {
auto activation_desc = mkldnn : : eltwise_forward : : desc (
prop_kind , algorithm , md , alpha , beta ) ;
activation _pd_. reset ( new mkldnn : : eltwise_forward : : primitive_desc (
fwd _pd_. reset ( new mkldnn : : eltwise_forward : : primitive_desc (
activation_desc , engine_ ) ) ;
dev_ctx_ . SetBlob ( key_activation_pd , activation _pd_) ;
dev_ctx_ . SetBlob ( key_activation_pd , fwd _pd_) ;
}
}
return activation _pd_;
return fwd _pd_;
}
std : : shared_ptr < mkldnn : : eltwise_backward : : primitive_desc >
@ -366,23 +365,22 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
const mkldnn : : memory : : desc & src_md , float alpha , float beta ) {
const std : : string key_activation_pd = key_common_ + " @activation_pd " ;
const std : : string key_activation_bwd_pd = key_ + " @activation_bwd_pd " ;
activation_ bwd_pd_ =
bwd_pd_ =
std : : static_pointer_cast < mkldnn : : eltwise_backward : : primitive_desc > (
dev_ctx_ . GetBlob ( key_activation_bwd_pd ) ) ;
if ( activation_ bwd_pd_ = = nullptr ) {
activation _pd_ =
if ( bwd_pd_ = = nullptr ) {
fwd _pd_ =
std : : static_pointer_cast < mkldnn : : eltwise_forward : : primitive_desc > (
dev_ctx_ . GetBlob ( key_activation_pd ) ) ;
// PD from FWD op has to exist.
PADDLE_ENFORCE ( activation_pd_ ! = nullptr ,
" Eltwise MKL-DNN not found in cache! " ) ;
PADDLE_ENFORCE_NOT_NULL ( fwd_pd_ , " Eltwise MKL-DNN not found in cache! " ) ;
auto backward_desc = mkldnn : : eltwise_backward : : desc (
algorithm , diff_dst_md , src_md , alpha , beta ) ;
activation_ bwd_pd_. reset ( new mkldnn : : eltwise_backward : : primitive_desc (
backward_desc , engine_ , * activation _pd_) ) ;
dev_ctx_ . SetBlob ( key_activation_bwd_pd , activation_ bwd_pd_) ;
bwd_pd_. reset ( new mkldnn : : eltwise_backward : : primitive_desc (
backward_desc , engine_ , * fwd _pd_) ) ;
dev_ctx_ . SetBlob ( key_activation_bwd_pd , bwd_pd_) ;
}
return activation_ bwd_pd_;
return bwd_pd_;
}
std : : shared_ptr < mkldnn : : eltwise_forward > AcquireActivation (
@ -395,22 +393,25 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
dev_ctx_ . GetBlob ( prim_key ) ) ;
if ( eltwise_p = = nullptr ) {
eltwise_p = std : : make_shared < mkldnn : : eltwise_forward > (
* activation _pd_, * ( src_memory_p ) , * ( dst_memory_p ) ) ;
* fwd _pd_, * ( src_memory_p ) , * ( dst_memory_p ) ) ;
dev_ctx_ . SetBlob ( prim_key , eltwise_p ) ;
}
return eltwise_p ;
}
// TODO(jczaja): Merge all AcquireDstMemoryFromPrimitive into one
std : : shared_ptr < mkldnn : : memory > AcquireDstMemoryFromPrimitive ( void * ptr ) {
return this - > AcquireMemoryFromPrimitive (
activation_pd_ - > dst_primitive_desc ( ) , ptr , " @dst_mem_p " ) ;
template < typename T >
std : : shared_ptr < mkldnn : : memory > AcquireDstMemoryFromPrimitive (
framework : : Tensor * output , platform : : Place place ) {
T * ptr = output - > mutable_data < T > ( place ,
fwd_pd_ - > dst_primitive_desc ( ) . get_size ( ) ) ;
return this - > AcquireMemoryFromPrimitive ( fwd_pd_ - > dst_primitive_desc ( ) , ptr ,
" @dst_mem_p " ) ;
}
std : : shared_ptr < mkldnn : : memory > AcquireDiffSrcMemoryFromPrimitive ( void * ptr ) {
return this - > AcquireMemoryFromPrimitive (
activation_bwd_pd_ - > diff_src_primitive_desc ( ) , ptr , " @diff_src_mem_p " ) ;
return this - > AcquireMemoryFromPrimitive ( bwd_pd_ - > diff_src_primitive_desc ( ) ,
ptr , " @diff_src_mem_p " ) ;
}
std : : shared_ptr < mkldnn : : eltwise_backward > AcquireActivationBackward (
@ -424,7 +425,7 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
dev_ctx_ . GetBlob ( prim_key ) ) ;
if ( eltwise_bwd_p = = nullptr ) {
eltwise_bwd_p = std : : make_shared < mkldnn : : eltwise_backward > (
* activation_ bwd_pd_, * ( src_memory_p ) , * ( diff_dst_memory_p ) ,
* bwd_pd_, * ( src_memory_p ) , * ( diff_dst_memory_p ) ,
* ( diff_src_memory_p ) ) ;
dev_ctx_ . SetBlob ( prim_key , eltwise_bwd_p ) ;
}
@ -449,8 +450,8 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
}
private :
std : : shared_ptr < mkldnn : : eltwise_forward : : primitive_desc > activation _pd_;
std : : shared_ptr < mkldnn : : eltwise_backward : : primitive_desc > activation_ bwd_pd_;
std : : shared_ptr < mkldnn : : eltwise_forward : : primitive_desc > fwd _pd_;
std : : shared_ptr < mkldnn : : eltwise_backward : : primitive_desc > bwd_pd_;
} ;
class LRNMKLDNNHandler : public MKLDNNHandler {