@@ -32,49 +32,58 @@ using mkldnn::softmax_forward;
 using mkldnn::stream;
 using platform::to_void_cast;

 template <typename T>
 class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
  public:
-  SoftmaxMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx,
+  SoftmaxMKLDNNHandler(const std::vector<int>& dims,
+                       const mkldnn::memory::format fmt,
+                       const platform::MKLDNNDeviceContext& dev_ctx,
                        mkldnn::engine engine, const std::string& base_key)
-      : platform::MKLDNNHandler(dev_ctx, engine, base_key) {}
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
+        dims_(dims),
+        fmt_(fmt) {}

-  SoftmaxMKLDNNHandler(
-      std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd,
-      std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd,
-      const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
-      const std::string& base_key)
+  SoftmaxMKLDNNHandler(const std::vector<int>& dims,
+                       const mkldnn::memory::format fmt,
+                       const mkldnn::memory::format diff_fmt,
+                       const platform::MKLDNNDeviceContext& dev_ctx,
+                       mkldnn::engine engine, const std::string& base_key)
       : platform::MKLDNNHandler(dev_ctx, engine, base_key),
-        softmax_pd_(softmax_pd),
-        softmax_bwd_pd_(softmax_bwd_pd) {
+        dims_(dims),
+        fmt_(fmt),
+        diff_fmt_(diff_fmt) {
     // If we are in the Grad operator then update the key with a BWD suffix to
     // distinguish it from FWD memory primitives
+    // Key_common will allow access to the FWD PD from the cache
     key_ += "-BWD";
   }

-  std::shared_ptr<softmax_forward::primitive_desc>
-  AcquireSoftmaxPrimitiveDescriptor(const softmax_forward::desc& softmax_desc,
-                                    const mkldnn::engine& engine) {
-    // The Softmax PD has to be passed to the Grad op, which
-    // may be executed by a different thread, hence
-    // for that one we use a key that does not contain the TID
-    const std::string key_softmax_pd = key_common_ + "@softmax_pd";
-
-    softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
-        dev_ctx_.GetBlob(key_softmax_pd));
-    if (softmax_pd_ == nullptr) {
-      static std::mutex acquire_barrier;
-      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
-          acquire_barrier);
-      softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
-          dev_ctx_.GetBlob(key_softmax_pd));
-      if (softmax_pd_ == nullptr) {
-        softmax_pd_.reset(
-            new softmax_forward::primitive_desc(softmax_desc, engine));
-        dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_);
-      }
-    }
-
-    return softmax_pd_;
-  }
+  // TODO(jczaja): Once fwd_pd_ is moved to MKLDNNHandler then this function
+  // should be moved as well, e.g. SoftmaxMKLDNNHandler -> MKLDNNHandler<softmax_>
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(void* ptr) {
+    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(), fmt_,
+                               ptr, "@user_src_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDstMemory(void* ptr) {
+    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(), fmt_,
+                               ptr, "@user_dst_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(void* ptr) {
+    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(),
+                               diff_fmt_, ptr, "@user_diff_dst_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(void* ptr) {
+    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(),
+                               diff_fmt_, ptr, "@user_diff_src_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
+    this->AcquireSoftmaxPrimitiveDescriptor();
+    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
+                                            "@dst_mem_p");
+  }

   std::shared_ptr<mkldnn::softmax_forward> AcquireSoftmax(
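Note: the new Acquire*Memory helpers all follow the same string-keyed caching idiom built on dev_ctx_.GetBlob/SetBlob, as visible above. A minimal, self-contained sketch of that idiom, assuming an std::unordered_map as a stand-in for Paddle's device-context blob store (names and signature are illustrative, not the actual MKLDNNHandler implementation):

    #include <memory>
    #include <string>
    #include <unordered_map>
    #include "mkldnn.hpp"

    // Look a cached mkldnn::memory up by key; create and store it on a miss,
    // rebind the user's data pointer on a hit so the primitive can be reused.
    std::shared_ptr<mkldnn::memory> AcquireMemorySketch(
        std::unordered_map<std::string, std::shared_ptr<mkldnn::memory>>* cache,
        const std::string& key, const mkldnn::memory::primitive_desc& mpd,
        void* ptr) {
      auto it = cache->find(key);
      if (it == cache->end()) {
        auto mem_p = std::make_shared<mkldnn::memory>(mpd, ptr);  // wrap buffer
        it = cache->emplace(key, mem_p).first;                    // cache it
      } else {
        it->second->set_data_handle(ptr);  // same primitive, new data pointer
      }
      return it->second;
    }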
@@ -86,8 +95,9 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
         dev_ctx_.GetBlob(prim_key));
     if (softmax_p == nullptr) {
+      this->AcquireSoftmaxPrimitiveDescriptor();
       softmax_p = std::make_shared<mkldnn::softmax_forward>(
-          *softmax_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
+          *fwd_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
           *(static_cast<mkldnn::memory*>(dst_memory_p.get())));
       dev_ctx_.SetBlob(prim_key, softmax_p);
     }
@@ -103,8 +113,19 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
         dev_ctx_.GetBlob(prim_key));
     if (softmax_bwd_p == nullptr) {
+      auto data_softmax_md =
+          mkldnn::memory::desc(dims_, platform::MKLDNNGetDataType<T>(), fmt_);
+      auto diff_softmax_md = mkldnn::memory::desc(
+          dims_, platform::MKLDNNGetDataType<T>(), diff_fmt_);
+      // TODO(jczaja): Add support for other axes
+      auto softmax_bwd_desc = softmax_backward::desc(
+          diff_softmax_md, data_softmax_md, 1 /* dim: C */);
+      this->AcquireSoftmaxPrimitiveDescriptor();
+      auto softmax_bwd_pd = mkldnn::softmax_backward::primitive_desc(
+          softmax_bwd_desc, engine_, *fwd_pd_);
+
       softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
-          *softmax_bwd_pd_, *dst_memory_p, *diff_dst_memory_p,
+          softmax_bwd_pd, *dst_memory_p, *diff_dst_memory_p,
           *diff_src_memory_p);
       dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
     }
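For context on why AcquireSoftmaxPrimitiveDescriptor() is called before building the backward PD: in the MKL-DNN 0.x C++ API, softmax_backward::primitive_desc takes the forward primitive_desc as a hint (the third constructor argument above), so the Grad op must be able to reach the cached FWD PD. A minimal standalone sketch under that assumption, with illustrative shapes, engine, and variable names:

    #include <vector>
    #include "mkldnn.hpp"

    void BuildSoftmaxBwdPdSketch() {
      auto cpu_engine = mkldnn::engine(mkldnn::engine::cpu, 0);
      std::vector<int> nc_dims = {2, 10};  // example NC shape
      auto data_md = mkldnn::memory::desc(nc_dims, mkldnn::memory::data_type::f32,
                                          mkldnn::memory::format::nc);
      auto diff_md = mkldnn::memory::desc(nc_dims, mkldnn::memory::data_type::f32,
                                          mkldnn::memory::format::nc);
      // Forward descriptor/PD along axis 1 (C), mirroring the handler above
      auto fwd_desc = mkldnn::softmax_forward::desc(
          mkldnn::prop_kind::forward_scoring, data_md, 1);
      auto fwd_pd = mkldnn::softmax_forward::primitive_desc(fwd_desc, cpu_engine);
      // The backward PD requires the forward PD as the hint argument
      auto bwd_desc = mkldnn::softmax_backward::desc(diff_md, data_md, 1);
      auto bwd_pd =
          mkldnn::softmax_backward::primitive_desc(bwd_desc, cpu_engine, fwd_pd);
      (void)bwd_pd;
    }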
@@ -112,9 +133,41 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     return softmax_bwd_p;
   }

+ protected:
+  void AcquireSoftmaxPrimitiveDescriptor(void) {
+    // The Softmax PD has to be passed to the Grad op, which
+    // may be executed by a different thread, hence
+    // for that one we use a key that does not contain the TID
+    const std::string key_softmax_pd = key_common_ + "@softmax_pd";
+
+    fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
+        dev_ctx_.GetBlob(key_softmax_pd));
+    if (fwd_pd_ == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+      fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
+          dev_ctx_.GetBlob(key_softmax_pd));
+      if (fwd_pd_ == nullptr) {
+        // TODO(jczaja): Make it work along a chosen axis and for
+        // forward_training
+        // Normalization is done over the innermost dimension, e.g. C out of NC
+        auto md =
+            mkldnn::memory::desc(dims_, platform::MKLDNNGetDataType<T>(), fmt_);
+        auto softmax_desc =
+            softmax_forward::desc(prop_kind::forward_scoring, md, 1 /*dim: C*/);
+        fwd_pd_.reset(
+            new softmax_forward::primitive_desc(softmax_desc, engine_));
+        dev_ctx_.SetBlob(key_softmax_pd, fwd_pd_);
+      }
+    }
+  }
+
  private:
-  std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd_;
-  std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd_;
+  std::vector<int> dims_;
+  mkldnn::memory::format fmt_;
+  mkldnn::memory::format diff_fmt_;
+  std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd_;
 };

 template <typename T>
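The protected AcquireSoftmaxPrimitiveDescriptor() above uses a double-checked creation pattern: an unlocked lookup, then a re-check under a static mutex before creating the PD, so all threads end up sharing one descriptor cached under a TID-free key. A generic sketch of just that control flow, assuming the cache itself (like the device-context blob store it stands in for) is internally thread-safe; names are illustrative:

    #include <memory>
    #include <mutex>
    #include <string>
    #include <unordered_map>

    // Create-once helper: fast-path lookup, lock, re-check, then create.
    template <typename T, typename Factory>
    std::shared_ptr<T> AcquireOnce(
        std::unordered_map<std::string, std::shared_ptr<T>>* cache,
        const std::string& key, Factory make_object) {
      auto it = cache->find(key);  // unlocked fast path
      if (it != cache->end()) return it->second;
      static std::mutex acquire_barrier;
      std::lock_guard<std::mutex> guard(acquire_barrier);
      it = cache->find(key);  // re-check: another thread may have created it
      if (it == cache->end()) {
        it = cache->emplace(key, make_object()).first;  // create exactly once
      }
      return it->second;
    }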
@@ -154,21 +207,14 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     const std::string key =
         platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Output("Out"));

-    SoftmaxMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
-    // Currently only NC data format is supported
-    auto softmax_md = MKLDNNMemDesc(
-        {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
-    // Normalization is done over the innermost dimension, e.g. C out of NC
-    auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
-                                              softmax_md, 1 /*dim: C*/);
-    auto softmax_pd =
-        handler.AcquireSoftmaxPrimitiveDescriptor(softmax_desc, mkldnn_engine);
-
+    SoftmaxMKLDNNHandler<T> handler(softmax_tz, mkldnn::memory::format::nc,
+                                    dev_ctx, mkldnn_engine, key);
+    // Currently only NC data format is supported
     auto softmax_src_memory_p =
-        handler.AcquireSrcMemory(softmax_md, to_void_cast<T>(input_data));
+        handler.AcquireSrcMemory(to_void_cast<T>(input_data));
     auto softmax_dst_memory_p =
-        handler.AcquireDstMemory(softmax_md, to_void_cast<T>(output_data));
+        handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
     auto softmax_p =
         handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);
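The remainder of the forward kernel is untouched by this diff; for context, the acquired primitive is then executed through an eager MKL-DNN 0.x stream, roughly as sketched below (a paraphrase of the surrounding kernel code, not part of this change):

    // Submit the cached softmax primitive and block until it finishes.
    std::vector<mkldnn::primitive> pipeline{*softmax_p};
    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();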
@@ -241,25 +287,16 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
     // TODO(jczaja): Add layouts support when there is a need to do so
     // Two dimensional softmax does support NC format
-    auto data_softmax_md = MKLDNNMemDesc(
-        {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
-    auto diff_softmax_md = MKLDNNMemDesc(
-        {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
-    // Normalization is done over the innermost dimension, e.g. C out of NC
-    auto softmax_bwd_desc =
-        softmax_backward::desc(diff_softmax_md, data_softmax_md, 1 /* dim: C */);
-    auto softmax_bwd_pd =
-        std::make_shared<mkldnn::softmax_backward::primitive_desc>(
-            softmax_bwd_desc, mkldnn_engine, *softmax_pd);
-
-    SoftmaxMKLDNNHandler handler(softmax_pd, softmax_bwd_pd, dev_ctx,
+    SoftmaxMKLDNNHandler<T> handler(softmax_tz, mkldnn::memory::format::nc,
+                                    mkldnn::memory::format::nc, dev_ctx,
                                     mkldnn_engine, key);

-    auto dst_memory_p =
-        handler.AcquireDstMemory(data_softmax_md, to_void_cast<T>(dst_data));
-    auto diff_dst_memory_p = handler.AcquireDiffDstMemory(
-        diff_softmax_md, to_void_cast<T>(diff_dst_ptr));
-    auto diff_src_memory_p = handler.AcquireDiffSrcMemory(
-        diff_softmax_md, to_void_cast<T>(diff_src_ptr));
+    auto dst_memory_p = handler.AcquireDstMemory(to_void_cast<T>(dst_data));
+    auto diff_dst_memory_p =
+        handler.AcquireDiffDstMemory(to_void_cast<T>(diff_dst_ptr));
+    auto diff_src_memory_p =
+        handler.AcquireDiffSrcMemory(to_void_cast<T>(diff_src_ptr));

     // Get the primitive from the device context
     auto softmax_bwd_p = handler.AcquireSoftmaxBackward(