@ -15,6 +15,7 @@ limitations under the License. */
# include "paddle/fluid/framework/data_layout_transform.h"
# include "paddle/fluid/operators/pool_op.h"
# include "paddle/fluid/platform/mkldnn_helper.h"
# include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
@ -29,23 +30,23 @@ using mkldnn::stream;
using platform : : to_void_cast ;
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial
static std : : string gethash ( const memory : : dims & input_dims ,
const std : : string & pooling_type ,
const std : : vector < int > & ksize ,
const std : : vector < int > & strides ,
const std : : vector < int > & paddings ,
const memory : : data_type & dt ,
const std : : string & suffix ) {
auto dims2str = [ ] ( const memory : : dims & operand_dims ) {
std : : string dstr = " " ;
for ( size_t i = 0 ; i < operand_dims . size ( ) ; + + i ) {
dstr + = std : : to_string ( operand_dims [ i ] ) + " - " ;
}
return dstr ;
} ;
return dims2str ( input_dims ) + dims2str ( ksize ) + dims2str ( strides ) +
dims2str ( paddings ) + std : : to_string ( dt ) + pooling_type + suffix ;
std : : string CreateKey ( const paddle : : framework : : ExecutionContext & ctx ,
const memory : : dims & input_dims ,
const std : : string & pooling_type ,
const std : : vector < int > & ksize ,
const std : : vector < int > & strides ,
const std : : vector < int > & paddings ,
const memory : : data_type & dt , const std : : string & suffix ) {
std : : string key ;
key . reserve ( platform : : MKLDNNHandler : : MaxKeyLength ) ;
platform : : MKLDNNHandler : : AppendKeyDims ( & key , input_dims ) ;
platform : : MKLDNNHandler : : AppendKey ( & key , pooling_type ) ;
platform : : MKLDNNHandler : : AppendKeyVec ( & key , ksize ) ;
platform : : MKLDNNHandler : : AppendKeyVec ( & key , strides ) ;
platform : : MKLDNNHandler : : AppendKeyVec ( & key , paddings ) ;
platform : : MKLDNNHandler : : AppendKey ( & key , std : : to_string ( dt ) ) ;
platform : : MKLDNNHandler : : AppendKey ( & key , suffix ) ;
return key ;
}
static inline int ComputeCeiledOutput ( int input_size , int kernel_size ,
@ -114,8 +115,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn : : memory : : data_type dt =
paddle : : framework : : ToMKLDNNDataType ( input - > type ( ) ) ;
const std : : string key = gethash( src_tz , pooling_type , ksize , strides ,
paddings , dt , ctx . op ( ) . Output ( " Out " ) ) ;
const std : : string key = CreateKey( ctx , src_tz , pooling_type , ksize , strides ,
paddings , dt , ctx . op ( ) . Output ( " Out " ) ) ;
const std : : string key_pool_p = key + " @pool_p " ;
const std : : string key_pool_pd = key + " @pool_pd " ;
const std : : string key_pool_src_mem_p = key + " @pool_src_mem_p " ;
@ -294,8 +295,8 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when referring info from device context
const std : : string key =
gethash( diff_src_tz , pooling_type , ksize , strides , paddings ,
memory : : data_type : : f32 , ctx . op ( ) . Input ( " Out " ) ) ;
CreateKey( ctx , diff_src_tz , pooling_type , ksize , strides , paddings ,
memory : : data_type : : f32 , ctx . op ( ) . Input ( " Out " ) ) ;
const std : : string key_pool_bwd_p = key + " @pool_bwd_p " ;
const std : : string key_pool_diff_src_mem_p = key + " @pool_diff_src_mem_p " ;
const std : : string key_pool_diff_dst_mem_p = key + " @pool_diff_dst_mem_p " ;