@@ -1160,18 +1160,24 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
scale_data , mask ) ;
}
// --- CreatePostOps (diff residue: pre- and post-change versions follow) ---
// NOTE(review): this span is a unified diff whose '+'/'-' markers were
// stripped, so the old 4-parameter signature and the new 6-parameter
// signature appear back-to-back; "::" has also been split into ": :".
// Old signature (pre-change):
mkldnn : : primitive_attr CreatePostOps ( bool fuse_relu , bool fuse_residual_conn ,
bool fuse_brelu ,
float fuse_brelu_threshold ) const {
// New signature (post-change): adds two optional quantization parameters —
// output_shift_scale (output scale factors; empty vector = feature off,
// preserving the old behavior) and sum_scale (scale for the residual-sum
// post-op; default 1.0f matches the old hard-coded value).
mkldnn : : primitive_attr CreatePostOps (
bool fuse_relu , bool fuse_residual_conn , bool fuse_brelu ,
float fuse_brelu_threshold ,
const std : : vector < float > output_shift_scale = { } ,
float sum_scale = 1.0f ) const {
mkldnn : : primitive_attr conv_attr ;
mkldnn : : post_ops post_operations ;
// New (post-change): install output scales when scale factors were given.
// The mask 1<<1 presumably selects per-output-channel scaling when more
// than one scale is supplied, and 0 selects a single common scale —
// TODO(review): confirm against the MKL-DNN set_output_scales docs.
if ( output_shift_scale . size ( ) > 0 ) {
int mask = output_shift_scale . size ( ) > 1 ? 1 < < 1 : 0 ;
conv_attr . set_output_scales ( mask , output_shift_scale ) ;
}
// Fusion with Elementwise layer relies on adding a sum post-operation with
// the scale parameter. It is assumed that when fuse_residual_connection is
// true, the output tensor contains the data coming from residual
// connection. The result of this post_op is:
// Output = scale * Output + Conv_Out.
if ( fuse_residual_conn ) {
// Old line appends a fixed 1.0f; new line forwards the caller-supplied
// sum_scale (identical result for the default sum_scale = 1.0f).
post_operations . append_sum ( 1.0f ) ;
post_operations . append_sum ( sum_scale ) ;
}
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
// (the rest of this function lies beyond this diff hunk)
@@ -1202,7 +1208,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
// Fragment of AcquireConvolutionPrimitiveDescriptor's parameter list (the
// signature starts before this hunk). Diff residue: old tail (ending at
// fwd_prop_kind) followed by the new tail, which appends the same two
// optional quantization parameters that CreatePostOps gained, keeping the
// call backward-compatible.
const std : : vector < int > & paddings , const mkldnn : : engine & engine ,
const bool fuse_relu , const bool fuse_residual_conn ,
const bool fuse_brelu , const float fuse_brelu_threshold ,
mkldnn : : prop_kind fwd_prop_kind ) {
mkldnn : : prop_kind fwd_prop_kind ,
const std : : vector < float > output_shift_scale = { } ,
const float sum_scale = 1.0f ) {
// Conv PD has to be passed to Grad op that
// may be executed by a different thread, hence
// for that one we use key that does not contain TID
@@ -1232,8 +1240,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
// Second fragment of AcquireConvolutionPrimitiveDescriptor (interior of the
// function body; the surrounding conv_desc construction starts before this
// hunk). Diff residue: the old 4-argument CreatePostOps call, then the new
// call that forwards output_shift_scale and sum_scale through to the
// primitive attributes.
src , weights , dst , stride_dims , padding_dims ,
padding_dims , mkldnn : : padding_kind : : zero ) ;
// Old call (pre-change):
mkldnn : : primitive_attr conv_attr = CreatePostOps (
fuse_relu , fuse_residual_conn , fuse_brelu , fuse_brelu_threshold ) ;
// New call (post-change): forwards the quantization parameters.
mkldnn : : primitive_attr conv_attr =
CreatePostOps ( fuse_relu , fuse_residual_conn , fuse_brelu ,
fuse_brelu_threshold , output_shift_scale , sum_scale ) ;
// Build the primitive descriptor with the (possibly scaled) attributes.
conv_pd_ . reset ( new typename forward_t : : primitive_desc (
conv_desc , conv_attr , engine ) ) ;
@@ -1393,10 +1402,10 @@ template <typename T>
static void SetDstMemoryHandler (
const framework : : ExecutionContext & ctx , framework : : Tensor * output ,
const std : : shared_ptr < ConvMKLDNNHandler > & handler ,
std : : shared_ptr < mkldnn : : memory > * dst_memory_p ) {
std : : shared_ptr < mkldnn : : memory > dst_memory_p ) {
T * output_data =
output - > mutable_data < T > ( ctx . GetPlace ( ) , handler - > GetDstMemorySize ( ) ) ;
( * dst_memory_p ) - > set_data_handle ( to_void_cast < T > ( output_data ) ) ;
dst_memory_p - > set_data_handle ( to_void_cast < T > ( output_data ) ) ;
}
template < typename T >