@ -103,24 +103,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
bool is_test = ctx . Attr < bool > ( " is_test " ) ;
std : : string key = platform : : MKLDNNHandler : : GetHash (
src_tz , std : : to_string ( algorithm ) + std : : to_string ( alpha ) +
std : : to_string ( beta ) + ctx . op ( ) . Input ( " X " ) ) ;
// TODO(jczaja): Make it Thread safe
// save input data and layout to be referred in backward path
const std : : string key_src_data = key + " @eltwise_fwd_src_data " ;
const std : : string key_src_layout = key + " @eltwise_fwd_src_layout " ;
// Just in case some int8 models are run interchangebly
// with float models then format maybe diffrent
key + = std : : to_string ( src_format ) ;
const std : : string key_src_mem = key + " @eltwise_fwd_src_mem " ;
auto p_src_data = std : : make_shared < const T * > ( x_data ) ;
auto p_src_layout = std : : make_shared < memory : : format > ( src_format ) ;
if ( ! is_test ) {
dev_ctx . SetBlob ( key_src_data , p_src_data ) ;
dev_ctx . SetBlob ( key_src_layout , p_src_layout ) ;
}
std : : string key = platform : : ActivationMKLDNNHandler : : GetHash (
src_tz , algorithm , src_format , alpha , beta , ctx . op ( ) . Input ( " X " ) ) ;
platform : : ActivationMKLDNNHandler handler ( dev_ctx , mkldnn_engine , key ) ;
@ -133,11 +117,6 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
algorithm , md , alpha , beta ) ;
auto src_memory_p = handler . AcquireSrcMemory ( md , to_void_cast < T > ( x_data ) ) ;
// jczaja: Workaround, src_memory_p is needed in BWD so it has
// to be accessible under key not dependant on TID
if ( ! is_test ) {
dev_ctx . SetBlob ( key_src_mem , src_memory_p ) ;
}
auto dst_memory_p =
handler . AcquireDstMemoryFromPrimitive ( to_void_cast < T > ( y_data ) ) ;
@ -158,6 +137,9 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
auto & dev_ctx = ctx . template device_context < MKLDNNDeviceContext > ( ) ;
const auto & mkldnn_engine = dev_ctx . GetEngine ( ) ;
const auto * x = ctx . Input < Tensor > ( " X " ) ;
const T * x_data = x - > data < T > ( ) ;
const auto * diff_y = ctx . Input < Tensor > ( framework : : GradVarName ( " Out " ) ) ;
auto * diff_x = ctx . Output < Tensor > ( framework : : GradVarName ( " X " ) ) ;
@ -169,47 +151,41 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
std : : vector < int > diff_dst_tz = framework : : vectorize2int ( diff_y - > dims ( ) ) ;
// diff_dst and src dims should be the same
auto src_format =
diff_dst_tz . size ( ) = = 2 ? mkldnn : : memory : : format : : nc : x - > format ( ) ;
auto diff_y_format =
diff_dst_tz . size ( ) = = 2 ? mkldnn : : memory : : format : : nc : diff_y - > format ( ) ;
auto diff_dst_md = platform : : MKLDNNMemDesc (
diff_dst_tz , platform : : MKLDNNGetDataType < T > ( ) , diff_y_format ) ;
std : : string key = platform : : MKLDNNHandler : : GetHash (
diff_dst_tz , std : : to_string ( algorithm ) + std : : to_string ( alpha ) +
std : : to_string ( beta ) + ctx . op ( ) . Input ( " X " ) ) ;
std : : string key = platform : : ActivationMKLDNNHandler : : GetHash (
diff_dst_tz , algorithm , src_format , alpha , beta , ctx . op ( ) . Input ( " X " ) ) ;
const std : : string key_src_data = key + " @eltwise_fwd_src_data " ;
const std : : string key_src_layout = key + " @eltwise_fwd_src_layout " ;
// Get Data from FWD op
const auto p_src_layout =
std : : static_pointer_cast < memory : : format > ( dev_ctx . GetBlob ( key_src_layout ) ) ;
const auto p_src_data =
std : : static_pointer_cast < T * > ( dev_ctx . GetBlob ( key_src_data ) ) ;
key + = std : : to_string ( * p_src_layout ) ;
const std : : string key_src_mem = key + " @eltwise_fwd_src_mem " ;
auto src_memory =
std : : static_pointer_cast < mkldnn : : memory > ( dev_ctx . GetBlob ( key_src_mem ) ) ;
PADDLE_ENFORCE ( src_memory ! = nullptr ,
" Fail to find src_memory in device context " ) ;
src_memory - > set_data_handle ( * p_src_data ) ;
auto src_md = platform : : MKLDNNMemDesc (
diff_dst_tz , platform : : MKLDNNGetDataType < T > ( ) , src_format ) ;
platform : : ActivationMKLDNNHandler handler ( dev_ctx , mkldnn_engine , key ) ;
auto src_memory_p = handler . AcquireSrcMemory ( src_md , to_void_cast < T > ( x_data ) ) ;
auto diff_dst_memory_p =
handler . AcquireDiffDstMemory ( diff_dst_md , to_void_cast < T > ( diff_y_data ) ) ;
auto activation_backward_pd =
handler . AcquireActivationBackwardPrimitiveDescriptor (
algorithm , diff_dst_md , src_memory - > get_primitive_desc ( ) . desc ( ) ,
algorithm , diff_dst_md , src_memory_p - > get_primitive_desc ( ) . desc ( ) ,
alpha , beta ) ;
auto diff_src_memory_p =
handler . AcquireDiffSrcMemoryFromPrimitive ( diff_x_data ) ;
auto activation_backward_p = handler . AcquireActivationBackward (
diff_src_memory_p , diff_dst_memory_p , src_memory ) ;
diff_src_memory_p , diff_dst_memory_p , src_memory _p ) ;
// push primitive to stream and wait until it's executed
std : : vector < primitive > pipeline ;