// NOTE(review): stray patch metadata preserved from an applied diff hunk:
// @@ -53,25 +53,60 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
" Softmax input and output dimensions should match " ) ;
// Same memory descriptor to be used for input and output
memory : : dims softmax_tz = { src_tz [ 0 ] , src_tz [ 1 ] } ;
// Currently only supports NC data format
// TODO(jczaja-intel): support more formats
auto softmax_md =
MKLDNNMemDesc ( { softmax_tz } , memory : : f32 , memory : : format : : nc ) ;
// Normalization is made after innermost dimension eg. C out of NC
auto softmax_desc = softmax_forward : : desc ( prop_kind : : forward_scoring ,
softmax_md , 1 /*dim: C*/ ) ;
// create memory primitives
auto softmax_src_memory =
memory ( { softmax_md , mkldnn_engine } ,
static_cast < void * > ( const_cast < T * > ( input_data ) ) ) ;
auto softmax_dst_memory =
memory ( { softmax_md , mkldnn_engine } ,
static_cast < void * > ( const_cast < T * > ( output_data ) ) ) ;
auto softmax_prim_desc =
softmax_forward : : primitive_desc ( softmax_desc , mkldnn_engine ) ;
auto softmax = softmax_forward ( softmax_prim_desc , softmax_src_memory ,
softmax_dst_memory ) ;
std : : vector < primitive > pipeline { softmax } ;
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Each MKLDNN operator may have diffrent hashing function
auto gethash = [ ] ( memory : : dims & operand_dims ) {
return std : : string ( std : : to_string ( operand_dims [ 0 ] ) + " - " +
std : : to_string ( operand_dims [ 1 ] ) ) ;
} ;
const std : : string key = gethash ( softmax_tz ) ;
const std : : string key_softmax_p = key + " @softmax_p " ;
const std : : string key_softmax_src_mem_p = key + " @softmax_src_mem_p " ;
const std : : string key_softmax_dst_mem_p = key + " @softmax_dst_mem_p " ;
std : : shared_ptr < void > softmax_p = dev_ctx . GetBlob ( key_softmax_p ) ;
if ( softmax_p = = nullptr ) {
// Currently only NC data format is supported
auto softmax_md =
MKLDNNMemDesc ( { softmax_tz } , memory : : f32 , memory : : format : : nc ) ;
// Normalization is made after innermost dimension eg. C out of NC
auto softmax_desc = softmax_forward : : desc ( prop_kind : : forward_scoring ,
softmax_md , 1 /*dim: C*/ ) ;
// create memory primitives
auto softmax_src_memory_p = std : : make_shared < memory > (
memory : : primitive_desc { softmax_md , mkldnn_engine } ,
static_cast < void * > ( const_cast < T * > ( input_data ) ) ) ;
dev_ctx . SetBlob ( key_softmax_src_mem_p , softmax_src_memory_p ) ;
auto softmax_dst_memory_p = std : : make_shared < memory > (
memory : : primitive_desc { softmax_md , mkldnn_engine } ,
static_cast < void * > ( output_data ) ) ;
dev_ctx . SetBlob ( key_softmax_dst_mem_p , softmax_dst_memory_p ) ;
auto softmax_forward_pd =
std : : make_shared < softmax_forward : : primitive_desc > ( softmax_desc ,
mkldnn_engine ) ;
softmax_p = std : : make_shared < softmax_forward > (
* ( softmax_forward_pd . get ( ) ) ,
* ( static_cast < memory * > ( softmax_src_memory_p . get ( ) ) ) ,
* ( static_cast < memory * > ( softmax_dst_memory_p . get ( ) ) ) ) ;
dev_ctx . SetBlob ( key_softmax_p , softmax_p ) ;
} else {
// Primitives already exist
auto src_memory_p = std : : static_pointer_cast < memory > (
dev_ctx . GetBlob ( key_softmax_src_mem_p ) ) ;
PADDLE_ENFORCE ( src_memory_p ! = nullptr ,
" Fail to find softmax src mem_p in device context " ) ;
auto dst_memory_p = std : : static_pointer_cast < memory > (
dev_ctx . GetBlob ( key_softmax_dst_mem_p ) ) ;
PADDLE_ENFORCE ( dst_memory_p ! = nullptr ,
" Fail to find softmax dst mem_p in device context " ) ;
src_memory_p - > set_data_handle (
reinterpret_cast < void * > ( const_cast < T * > ( input_data ) ) ) ;
dst_memory_p - > set_data_handle ( output_data ) ;
}
std : : vector < primitive > pipeline {
* ( static_cast < softmax_forward : : primitive * > ( softmax_p . get ( ) ) ) } ;
stream ( stream : : kind : : eager ) . submit ( pipeline ) . wait ( ) ;
const bool is_test = ctx . Attr < bool > ( " is_test " ) ;