@ -18,6 +18,26 @@ limitations under the License. */
namespace paddle {
namespace operators {
using mkldnn : : memory ; // Note: paddle has also "memory" namespace
using mkldnn : : pooling_forward ;
using mkldnn : : pooling_backward ;
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial
static std : : string gethash ( memory : : dims & input_dims , std : : string & pooling_type ,
std : : vector < int > & ksize , std : : vector < int > & strides ,
std : : vector < int > & paddings , std : : string suffix ) {
auto dims2str = [ ] ( memory : : dims & operand_dims ) {
std : : string dstr = " " ;
for ( size_t i = 0 ; i < operand_dims . size ( ) ; + + i ) {
dstr + = std : : to_string ( operand_dims [ i ] ) + " - " ;
}
return dstr ;
} ;
return dims2str ( input_dims ) + dims2str ( ksize ) + dims2str ( strides ) +
dims2str ( paddings ) + pooling_type + suffix ;
}
template < typename T >
class PoolMKLDNNOpKernel : public paddle : : framework : : OpKernel < T > {
public :
@ -34,10 +54,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when saving info into device context
const std : : string key = ctx . op ( ) . Output ( " Out " ) ;
const std : : string key_pool_pd = key + " @pool_pd " ;
const std : : string key_pool_workspace_memory =
key + " @pool_workspace_memory " ;
std : : string pooling_type = ctx . Attr < std : : string > ( " pooling_type " ) ;
std : : vector < int > ksize = ctx . Attr < std : : vector < int > > ( " ksize " ) ;
@ -63,13 +79,28 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std : : vector < int > src_tz = paddle : : framework : : vectorize2int ( input - > dims ( ) ) ;
std : : vector < int > dst_tz = paddle : : framework : : vectorize2int ( output - > dims ( ) ) ;
const std : : string key = gethash ( src_tz , pooling_type , ksize , strides ,
paddings , ctx . op ( ) . Output ( " Out " ) ) ;
const std : : string key_pool_p = key + " @pool_p " ;
const std : : string key_pool_pd = key + " @pool_pd " ;
const std : : string key_pool_src_mem_p = key + " @pool_src_mem_p " ;
const std : : string key_pool_dst_mem_p = key + " @pool_dst_mem_p " ;
const std : : string key_pool_workspace_memory =
key + " @pool_workspace_memory " ;
auto pool_p =
std : : static_pointer_cast < pooling_forward > ( dev_ctx . GetBlob ( key_pool_p ) ) ;
if ( pool_p = = nullptr ) {
// TODO(pzelazko-intel): support more formats
auto src_md = platform : : MKLDNNMemDesc ( src_tz , mkldnn : : memory : : f32 ,
auto src_md =
platform : : MKLDNNMemDesc ( src_tz , platform : : MKLDNNGetDataType < T > ( ) ,
mkldnn : : memory : : format : : nchw ) ;
auto dst_md = platform : : MKLDNNMemDesc ( dst_tz , mkldnn : : memory : : f32 ,
auto dst_md =
platform : : MKLDNNMemDesc ( dst_tz , platform : : MKLDNNGetDataType < T > ( ) ,
mkldnn : : memory : : format : : nchw ) ;
std : : shared_ptr < mkldnn : : pooling_forward : : primitive_desc > pool_pd =
std : : shared_ptr < pooling_forward : : primitive_desc > pool_pd =
CreatePrimitiveDesc ( src_md , dst_md , strides , paddings , ksize ,
pooling_type , mkldnn_engine ) ;
@ -82,18 +113,37 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// save pool_workspace_memory to be referred in backward path
dev_ctx . SetBlob ( key_pool_workspace_memory , workspace_memory ) ;
auto src_memory =
mkldnn : : memory ( { src_md , mkldnn_engine } ,
auto pool_ src_memory_p = std : : make_shared < memory > (
memory : : primitive_desc { src_md , mkldnn_engine } ,
static_cast < void * > ( const_cast < T * > ( input_data ) ) ) ;
auto dst_memory =
mkldnn : : memory ( { dst_md , mkldnn_engine } ,
static_cast < void * > ( const_cast < T * > ( output_data ) ) ) ;
dev_ctx . SetBlob ( key_pool_src_mem_p , pool_src_memory_p ) ;
auto pool_prim = mkldnn : : pooling_forward ( * pool_pd , src_memory , dst_memory ,
auto pool_dst_memory_p = std : : make_shared < memory > (
memory : : primitive_desc { dst_md , mkldnn_engine } ,
static_cast < void * > ( output_data ) ) ;
dev_ctx . SetBlob ( key_pool_dst_mem_p , pool_dst_memory_p ) ;
pool_p = std : : make_shared < pooling_forward > (
* pool_pd , * ( pool_src_memory_p . get ( ) ) , * ( pool_dst_memory_p . get ( ) ) ,
* workspace_memory ) ;
dev_ctx . SetBlob ( key_pool_p , pool_p ) ;
} else {
// Primitives already exist
auto pool_src_memory_p =
std : : static_pointer_cast < memory > ( dev_ctx . GetBlob ( key_pool_src_mem_p ) ) ;
PADDLE_ENFORCE ( pool_src_memory_p ! = nullptr ,
" Fail to find pooling src mem_p in device context " ) ;
auto pool_dst_memory_p =
std : : static_pointer_cast < memory > ( dev_ctx . GetBlob ( key_pool_dst_mem_p ) ) ;
PADDLE_ENFORCE ( pool_dst_memory_p ! = nullptr ,
" Fail to find pooling dst mem_p in device context " ) ;
pool_src_memory_p - > set_data_handle (
reinterpret_cast < void * > ( const_cast < T * > ( input_data ) ) ) ;
pool_dst_memory_p - > set_data_handle ( output_data ) ;
}
// push primitive to stream and wait until it's executed
std : : vector < mkldnn : : primitive > pipeline { pool_prim } ;
std : : vector < mkldnn : : primitive > pipeline { * ( pool_p . get ( ) ) } ;
mkldnn : : stream ( mkldnn : : stream : : kind : : eager ) . submit ( pipeline ) . wait ( ) ;
}
@ -120,8 +170,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn : : memory : : primitive_desc workspace_md =
pooling_type = = " max "
? pool_pd - > workspace_primitive_desc ( )
: mkldnn : : memory : : primitive_desc (
{ { } , mkldnn : : memory : : f32 , mkldnn : : memory : : format : : nchw } ,
: mkldnn : : memory : : primitive_desc ( { { } ,
platform : : MKLDNNGetDataType < T > ( ) ,
mkldnn : : memory : : format : : nchw } ,
engine ) ;
auto p_workspace_memory = new mkldnn : : memory ( workspace_md ) ;
@ -140,13 +191,6 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const Tensor * out_grad = ctx . Input < Tensor > ( framework : : GradVarName ( " Out " ) ) ;
Tensor * in_x_grad = ctx . Output < Tensor > ( framework : : GradVarName ( " X " ) ) ;
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when referring info from device context
const std : : string key = ctx . op ( ) . Input ( " Out " ) ;
const std : : string key_pool_pd = key + " @pool_pd " ;
const std : : string key_pool_workspace_memory =
key + " @pool_workspace_memory " ;
std : : string pooling_type = ctx . Attr < std : : string > ( " pooling_type " ) ;
std : : vector < int > ksize = ctx . Attr < std : : vector < int > > ( " ksize " ) ;
std : : vector < int > strides = ctx . Attr < std : : vector < int > > ( " strides " ) ;
@ -171,11 +215,26 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std : : vector < int > diff_dst_tz =
paddle : : framework : : vectorize2int ( out_grad - > dims ( ) ) ;
auto diff_src_md = platform : : MKLDNNMemDesc ( diff_src_tz , mkldnn : : memory : : f32 ,
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when referring info from device context
const std : : string key = gethash ( diff_src_tz , pooling_type , ksize , strides ,
paddings , ctx . op ( ) . Input ( " Out " ) ) ;
const std : : string key_pool_bwd_p = key + " @pool_bwd_p " ;
const std : : string key_pool_diff_src_mem_p = key + " @pool_diff_src_mem_p " ;
const std : : string key_pool_diff_dst_mem_p = key + " @pool_diff_dst_mem_p " ;
const std : : string key_pool_pd = key + " @pool_pd " ;
const std : : string key_pool_workspace_memory =
key + " @pool_workspace_memory " ;
auto pool_bwd_p = std : : static_pointer_cast < pooling_backward > (
dev_ctx . GetBlob ( key_pool_bwd_p ) ) ;
if ( pool_bwd_p = = nullptr ) {
auto diff_src_md =
platform : : MKLDNNMemDesc ( diff_src_tz , platform : : MKLDNNGetDataType < T > ( ) ,
mkldnn : : memory : : format : : nchw ) ;
auto diff_dst_md = platform : : MKLDNNMemDesc ( diff_dst_tz , mkldnn : : memory : : f32 ,
auto diff_dst_md =
platform : : MKLDNNMemDesc ( diff_dst_tz , platform : : MKLDNNGetDataType < T > ( ) ,
mkldnn : : memory : : format : : nchw ) ;
// Retrieve pool_pd/pool_workspace_memory from device context
auto pool_pd =
std : : static_pointer_cast < mkldnn : : pooling_forward : : primitive_desc > (
@ -188,6 +247,15 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE ( workspace_memory ! = nullptr ,
" Fail to find workspace_memory in device context " ) ;
auto pool_diff_src_memory_p = std : : make_shared < memory > ( memory (
{ diff_src_md , mkldnn_engine } , static_cast < void * > ( in_x_grad_data ) ) ) ;
dev_ctx . SetBlob ( key_pool_diff_src_mem_p , pool_diff_src_memory_p ) ;
auto pool_diff_dst_memory_p = std : : make_shared < memory > (
memory ( { diff_dst_md , mkldnn_engine } ,
static_cast < void * > ( const_cast < T * > ( out_grad_data ) ) ) ) ;
dev_ctx . SetBlob ( key_pool_diff_dst_mem_p , pool_diff_dst_memory_p ) ;
auto pool_bwd_desc = mkldnn : : pooling_backward : : desc (
pooling_type = = " max " ? mkldnn : : algorithm : : pooling_max
: mkldnn : : algorithm : : pooling_avg ,
@ -196,18 +264,27 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto pool_bwd_pd = mkldnn : : pooling_backward : : primitive_desc (
pool_bwd_desc , mkldnn_engine , * pool_pd ) ;
auto diff_src_memory =
mkldnn : : memory ( { diff_src_md , mkldnn_engine } ,
static_cast < void * > ( const_cast < T * > ( in_x_grad_data ) ) ) ;
auto diff_dst_memory =
mkldnn : : memory ( { diff_dst_md , mkldnn_engine } ,
static_cast < void * > ( const_cast < T * > ( out_grad_data ) ) ) ;
auto bwd_prim = mkldnn : : pooling_backward (
pool_bwd_pd , diff_dst_memory , * workspace_memory , diff_src_memory ) ;
pool_bwd_p = std : : make_shared < pooling_backward > (
pool_bwd_pd , * ( pool_diff_dst_memory_p . get ( ) ) , * workspace_memory ,
* ( pool_diff_src_memory_p ) ) ;
dev_ctx . SetBlob ( key_pool_bwd_p , pool_bwd_p ) ;
} else {
// Primitives already exist
auto pool_diff_src_memory_p = std : : static_pointer_cast < memory > (
dev_ctx . GetBlob ( key_pool_diff_src_mem_p ) ) ;
PADDLE_ENFORCE ( pool_diff_src_memory_p ! = nullptr ,
" Fail to find pooling src mem_p in device context " ) ;
auto pool_diff_dst_memory_p = std : : static_pointer_cast < memory > (
dev_ctx . GetBlob ( key_pool_diff_dst_mem_p ) ) ;
PADDLE_ENFORCE ( pool_diff_dst_memory_p ! = nullptr ,
" Fail to find pooling dst mem_p in device context " ) ;
pool_diff_src_memory_p - > set_data_handle (
reinterpret_cast < void * > ( in_x_grad_data ) ) ;
pool_diff_dst_memory_p - > set_data_handle ( const_cast < T * > ( out_grad_data ) ) ;
}
// push primitive to stream and wait until it's executed
std : : vector < mkldnn : : primitive > pipeline { bwd_prim } ;
std : : vector < mkldnn : : primitive > pipeline { * ( pool_bwd_p . get ( ) ) } ;
mkldnn : : stream ( mkldnn : : stream : : kind : : eager ) . submit ( pipeline ) . wait ( ) ;
} // Compute()
} ;