diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
index 1135d672f7..8426348a11 100644
--- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
@@ -33,17 +33,18 @@ using mkldnn::stream;
 using platform::to_void_cast;
 
 template <typename T>
-class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
+class SoftmaxMKLDNNHandler
+    : public platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
+                                      mkldnn::softmax_backward> {
  public:
   SoftmaxMKLDNNHandler(const std::vector<int>& dims,
                        const MKLDNNMemoryFormat fmt,
                        const platform::MKLDNNDeviceContext& dev_ctx,
                        platform::Place cpu_place, const std::string& uniq_name)
-      : platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
-                                platform::CreateKey(dims, uniq_name)),
-        place_(cpu_place),
-        fwd_pd_(nullptr),
-        bwd_pd_(nullptr) {
+      : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
+                                 mkldnn::softmax_backward>(
+            dev_ctx, dev_ctx.GetEngine(), cpu_place,
+            platform::CreateKey(dims, uniq_name)) {
     this->AcquireSoftmaxPrimitiveDescriptor(dims, fmt);
   }
 
@@ -52,11 +53,10 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
                        const MKLDNNMemoryFormat diff_fmt,
                        const platform::MKLDNNDeviceContext& dev_ctx,
                        platform::Place cpu_place, const std::string& uniq_name)
-      : platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
-                                platform::CreateKey(dims, uniq_name)),
-        place_(cpu_place),
-        fwd_pd_(nullptr),
-        bwd_pd_(nullptr) {
+      : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
+                                 mkldnn::softmax_backward>(
+            dev_ctx, dev_ctx.GetEngine(), cpu_place,
+            platform::CreateKey(dims, uniq_name)) {
     // If we are in Grad operator then update a key with BWD suffix to
     // distinguish from FWD memory primitives
     // Key_common will allow access to FWD_PD from the cache
@@ -64,58 +64,19 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     this->AcquireSoftmaxBackwardPrimitiveDescriptor(dims, fmt, diff_fmt);
   }
 
-  // TODO(jczaja): Once fwd_pd_ are moved to MKLDNNHandler then this function
-  // should be moved as well eg. SoftmaxMKLDNNHandler -> MKLDNNHandler
-  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(const Tensor* input) {
-    const T* input_data = input->data<T>();
-    return this->AcquireMemoryFromPrimitive(fwd_pd_->src_primitive_desc(),
-                                            to_void_cast<T>(input_data),
-                                            "@src_mem_p");
-  }
-
-  // TODO(jczaja): Move to MKLDNNHandler as common code
-  std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
-    T* ptr = output->mutable_data<T>(place_,
-                                     fwd_pd_->dst_primitive_desc().get_size());
-    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
-                                            "@dst_mem_p");
-  }
-
-  std::shared_ptr<mkldnn::memory> AcquireDstMemory(const Tensor* output) {
-    const T* output_data = output->data<T>();
-    return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_primitive_desc(),
-                                            to_void_cast<T>(output_data),
-                                            "@bwd-dst_mem_p");
-  }
-
-  std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(const Tensor* diffdst) {
-    const T* ptr = diffdst->data<T>();
-    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_primitive_desc(),
-                                            to_void_cast<T>(ptr),
-                                            "@diff_dst_mem_p");
-  }
-
-  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
-      framework::Tensor* diffsrc) {
-    T* ptr = diffsrc->mutable_data<T>(
-        place_, bwd_pd_->diff_src_primitive_desc().get_size());
-    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
-                                            ptr, "@diff_src_mem_p");
-  }
-
   std::shared_ptr<mkldnn::softmax_forward> AcquireSoftmax(
       std::shared_ptr<mkldnn::memory> dst_memory_p,
       std::shared_ptr<mkldnn::memory> src_memory_p) {
     /*Generate key*/
-    auto prim_key = key_ + "@softmax_p";
+    auto prim_key = this->key_ + "@softmax_p";
     auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
-        dev_ctx_.GetBlob(prim_key));
+        this->dev_ctx_.GetBlob(prim_key));
     if (softmax_p == nullptr) {
       softmax_p = std::make_shared<mkldnn::softmax_forward>(
-          *fwd_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
+          *this->fwd_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
           *(static_cast<mkldnn::memory*>(dst_memory_p.get())));
-      dev_ctx_.SetBlob(prim_key, softmax_p);
+      this->dev_ctx_.SetBlob(prim_key, softmax_p);
     }
 
     return softmax_p;
@@ -125,13 +86,14 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
       std::shared_ptr<mkldnn::memory> dst_memory_p,
       std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
       std::shared_ptr<mkldnn::memory> diff_src_memory_p) {
-    auto prim_key = key_ + "@softmax_bwd_p";
+    auto prim_key = this->key_ + "@softmax_bwd_p";
     auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
-        dev_ctx_.GetBlob(prim_key));
+        this->dev_ctx_.GetBlob(prim_key));
     if (softmax_bwd_p == nullptr) {
       softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
-          *bwd_pd_, *dst_memory_p, *diff_dst_memory_p, *diff_src_memory_p);
-      dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
+          *this->bwd_pd_, *dst_memory_p, *diff_dst_memory_p,
+          *diff_src_memory_p);
+      this->dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
     }
 
     return softmax_bwd_p;
@@ -143,17 +105,17 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     // Softmax PD has to be passed to Grad op that
    // may be executed by a different thread, hence
     // for that one we use key that does not contain TID
-    const std::string key_softmax_pd = key_common_ + "@softmax_pd";
+    const std::string key_softmax_pd = this->key_common_ + "@softmax_pd";
 
-    fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
-        dev_ctx_.GetBlob(key_softmax_pd));
-    if (fwd_pd_ == nullptr) {
+    this->fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
+        this->dev_ctx_.GetBlob(key_softmax_pd));
+    if (this->fwd_pd_ == nullptr) {
       static std::mutex acquire_barrier;
       std::lock_guard<std::mutex> block_threads_until_finish_this_job(
           acquire_barrier);
-      fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
-          dev_ctx_.GetBlob(key_softmax_pd));
-      if (fwd_pd_ == nullptr) {
+      this->fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
+          this->dev_ctx_.GetBlob(key_softmax_pd));
+      if (this->fwd_pd_ == nullptr) {
        // TODO(jczaja): Make it work along a chosen axis and for
         // forward_training
         // Normalization is made after innermost dimension e.g. C out of NC
@@ -161,9 +123,9 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
             mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
         auto softmax_desc =
             softmax_forward::desc(prop_kind::forward_scoring, md, 1 /*dim: C*/);
-        fwd_pd_.reset(
-            new softmax_forward::primitive_desc(softmax_desc, engine_));
-        dev_ctx_.SetBlob(key_softmax_pd, fwd_pd_);
+        this->fwd_pd_.reset(
+            new softmax_forward::primitive_desc(softmax_desc, this->engine_));
+        this->dev_ctx_.SetBlob(key_softmax_pd, this->fwd_pd_);
       }
     }
   }
@@ -172,12 +134,12 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
       const std::vector<int>& dims, const mkldnn::memory::format fmt,
       const mkldnn::memory::format diff_fmt) {
     // Fwd_PD_ has to exist in order to create BWD_PD_
-    PADDLE_ENFORCE_NOT_NULL(fwd_pd_);
-    const std::string key_bwd_pd = key_ + "@softmax_bwd_pd";
-    bwd_pd_ =
+    PADDLE_ENFORCE_NOT_NULL(this->fwd_pd_);
+    const std::string key_bwd_pd = this->key_ + "@softmax_bwd_pd";
+    this->bwd_pd_ =
         std::static_pointer_cast<mkldnn::softmax_backward::primitive_desc>(
-            dev_ctx_.GetBlob(key_bwd_pd));
-    if (bwd_pd_ == nullptr) {
+            this->dev_ctx_.GetBlob(key_bwd_pd));
+    if (this->bwd_pd_ == nullptr) {
       auto data_softmax_md =
           mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
       auto diff_softmax_md = mkldnn::memory::desc(
@@ -185,16 +147,11 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
          dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
       // TODO(jczaja): Add support for other axes
       auto backward_desc = softmax_backward::desc(
          diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
-      bwd_pd_.reset(new mkldnn::softmax_backward::primitive_desc(
-          backward_desc, engine_, *fwd_pd_));
-      dev_ctx_.SetBlob(key_bwd_pd, bwd_pd_);
+      this->bwd_pd_.reset(new mkldnn::softmax_backward::primitive_desc(
+          backward_desc, this->engine_, *this->fwd_pd_));
+      this->dev_ctx_.SetBlob(key_bwd_pd, this->bwd_pd_);
     }
   }
-
- private:
-  platform::Place place_;
-  std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd_;
-  std::shared_ptr<mkldnn::softmax_backward::primitive_desc> bwd_pd_;
 };
 
 template <typename T>
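A note on the pervasive `this->` additions in this file: `MKLDNNHandlerT` is a dependent base class (it depends on the template parameter `T`), so under C++ two-phase name lookup, unqualified members such as `key_` or `dev_ctx_` are no longer found at template-definition time and must be qualified through `this->`. A minimal standalone illustration of the rule (the names `Base`, `Derived`, and `member_` are illustrative, not from the patch):

    template <typename T>
    struct Base {
      int member_ = 0;
    };

    template <typename T>
    struct Derived : Base<T> {
      int Get() {
        // return member_;   // error: 'member_' was not declared in this scope
        return this->member_;  // OK: lookup is deferred to instantiation
      }
    };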
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index 7131266f3f..1ff568cef3 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <mkldnn.h>
 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/place.h"
@@ -206,7 +207,7 @@ inline std::string CreateKey(ArgTypes&&... args) {
   std::string key;
   key.reserve(256);
   using expand_type = int[];
-  expand_type{0, (AppendKey(&key, args), 0)...};
+  expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
   return key;
 }
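The `std::forward` fix above preserves each argument's value category when `CreateKey` hands it to `AppendKey`; without it, every argument in the pack is passed as an lvalue. A self-contained sketch of the same pack-expansion idiom (the `AppendKey` overloads below are simplified stand-ins for the real ones, and `CreateKeySketch` is a hypothetical name):

    #include <string>
    #include <utility>

    inline void AppendKey(std::string* key, const std::string& s) { *key += s; }
    inline void AppendKey(std::string* key, int i) { *key += std::to_string(i); }

    template <typename... ArgTypes>
    std::string CreateKeySketch(ArgTypes&&... args) {
      std::string key;
      key.reserve(256);
      using expand_type = int[];
      // The braced array initializer evaluates AppendKey once per argument,
      // in order, left to right.
      expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
      return key;
    }

    // CreateKeySketch(std::string("softmax"), 2, 3) yields "softmax23".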
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index 911dffff42..1ebc89a8af 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -29,6 +29,90 @@ namespace platform {
 using user_function = std::function<std::unique_ptr<float[]>(const float*)>;
 using memory = mkldnn::memory;
 
+template <typename T, typename TForward, typename TBackward>
+class MKLDNNHandlerT {
+ public:
+  MKLDNNHandlerT(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
+                 platform::Place cpu_place, const std::string& base_key)
+      : dev_ctx_(dev_ctx),
+        engine_(engine),
+        place_(cpu_place),
+        key_common_(base_key),
+        fwd_pd_(nullptr),
+        bwd_pd_(nullptr) {
+    if (platform::get_cur_mkldnn_session_id() !=
+        platform::kMKLDNNSessionID_Default) {
+      key_ = key_common_;
+    } else {
+      key_ = key_common_ + "-t:" + ThreadIDasStr();
+    }
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
+      const framework::Tensor* input) {
+    const T* input_data = input->data<T>();
+    return this->AcquireMemoryFromPrimitive(fwd_pd_->src_primitive_desc(),
+                                            to_void_cast<T>(input_data),
+                                            "@src_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
+    T* ptr = output->mutable_data<T>(place_,
+                                     fwd_pd_->dst_primitive_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
+                                            "@dst_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDstMemory(
+      const framework::Tensor* output) {
+    const T* output_data = output->data<T>();
+    return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_primitive_desc(),
+                                            to_void_cast<T>(output_data),
+                                            "@bwd-dst_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
+      const framework::Tensor* diffdst) {
+    const T* ptr = diffdst->data<T>();
+    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_primitive_desc(),
+                                            to_void_cast<T>(ptr),
+                                            "@diff_dst_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
+      framework::Tensor* diffsrc) {
+    T* ptr = diffsrc->mutable_data<T>(
+        place_, bwd_pd_->diff_src_primitive_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
+                                            ptr, "@diff_src_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
+      mkldnn::memory::primitive_desc mdp, void* ptr,
+      const std::string& suffix) {
+    auto local_key = key_ + suffix;
+    auto mem_p =
+        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
+    if (mem_p == nullptr) {
+      mem_p = std::make_shared<mkldnn::memory>(mdp, ptr);
+      dev_ctx_.SetBlob(local_key, mem_p);
+    } else {
+      mem_p->set_data_handle(ptr);
+    }
+    return mem_p;
+  }
+
+ protected:
+  const MKLDNNDeviceContext& dev_ctx_;
+  mkldnn::engine engine_;
+  platform::Place place_;
+  std::string key_;
+  std::string key_common_;
+  std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
+  std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_;
+};
+
+// TODO(grygielski) this class will be deleted later.
 class MKLDNNHandler {
  public:
   MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
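With the shared Acquire* methods now living in `MKLDNNHandlerT`, a forward kernel only needs the derived handler. A hedged sketch of typical use (variable names such as `dims`, `fmt`, `input`, and `output` are placeholders for the usual kernel plumbing, not code from this patch):

    SoftmaxMKLDNNHandler<float> handler(dims, fmt, dev_ctx, ctx.GetPlace(),
                                        ctx.op().Output("Out"));
    auto src_mem_p = handler.AcquireSrcMemory(input);   // cached under key_
    auto dst_mem_p = handler.AcquireDstMemory(output);  // allocates via place_
    auto softmax_p = handler.AcquireSoftmax(dst_mem_p, src_mem_p);

    std::vector<mkldnn::primitive> pipeline{*softmax_p};
    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();

Because `AcquireMemoryFromPrimitive` only updates the data handle on a cache hit, repeated executions of the same op reuse the cached mkldnn::memory and primitive objects instead of recreating them.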
@@ -255,7 +339,9 @@ class SumMKLDNNHandler : public MKLDNNHandler {
 };
 
 template <typename T>
-class ActivationMKLDNNHandler : public MKLDNNHandler {
+class ActivationMKLDNNHandler
+    : public MKLDNNHandlerT<T, mkldnn::eltwise_forward,
+                            mkldnn::eltwise_backward> {
  public:
   ActivationMKLDNNHandler(const std::vector<int>& dims,
                           mkldnn::algorithm algorithm, float alpha, float beta,
@@ -264,12 +350,11 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
                           platform::Place cpu_place,
                           const std::string& unique_name)
-      : platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
-                                platform::CreateKey(dims, algorithm, fmt, alpha,
-                                                    beta, unique_name)),
-        place_(cpu_place),
-        fwd_pd_(nullptr),
-        bwd_pd_(nullptr) {
+      : platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
+                                 mkldnn::eltwise_backward>(
+            dev_ctx, dev_ctx.GetEngine(), cpu_place,
+            platform::CreateKey(dims, algorithm, fmt, alpha, beta,
+                                unique_name)) {
     AcquireActivationPrimitiveDescriptor(
         is_test ? mkldnn::prop_kind::forward_inference
                 : mkldnn::prop_kind::forward_training,
@@ -284,76 +369,37 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
                           platform::Place cpu_place,
                           const std::string& unique_name)
-      : platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
-                                platform::CreateKey(dims, algorithm, fmt, alpha,
-                                                    beta, unique_name)),
-        place_(cpu_place),
-        fwd_pd_(nullptr),
-        bwd_pd_(nullptr) {
+      : platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
+                                 mkldnn::eltwise_backward>(
+            dev_ctx, dev_ctx.GetEngine(), cpu_place,
+            platform::CreateKey(dims, algorithm, fmt, alpha, beta,
+                                unique_name)) {
     AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
                                          algorithm, dims, fmt, alpha, beta);
     AcquireActivationBackwardPrimitiveDescriptor(algorithm, dims, fmt,
                                                  diff_fmt, alpha, beta);
   }
 
-  // TODO(jczaja): Once fwd_pd_ are moved to MKLDNNHandler then this
-  // function
-  // should be moved as well eg. ActivationMKLDNNHandler ->
-  // MKLDNNHandler
-  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
-      const framework::Tensor* input) {
-    const T* input_data = input->data<T>();
-    return this->AcquireMemoryFromPrimitive(fwd_pd_->src_primitive_desc(),
-                                            to_void_cast<T>(input_data),
-                                            "@src_mem_p");
-  }
-
   std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory(
       const framework::Tensor* input) {
     const T* input_data = input->data<T>();
-    return this->AcquireMemoryFromPrimitive(bwd_pd_->src_primitive_desc(),
+    return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_primitive_desc(),
                                             to_void_cast<T>(input_data),
                                             "@bwd-src_mem_p");
   }
 
-  // TODO(jczaja): Move to MKLDNNHandler as common code
-  std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
-    T* ptr = output->mutable_data<T>(place_,
-                                     fwd_pd_->dst_primitive_desc().get_size());
-    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
-                                            "@dst_mem_p");
-  }
-
-  // TODO(jczaja): Move to MKLDNNHandler as common code
-  std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
-      const framework::Tensor* diffdst) {
-    const T* ptr = diffdst->data<T>();
-    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_primitive_desc(),
-                                            to_void_cast<T>(ptr),
-                                            "@diff_dst_mem_p");
-  }
-
-  // TODO(jczaja): Move to MKLDNNHandler as common code
-  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
-      framework::Tensor* diffsrc) {
-    T* ptr = diffsrc->mutable_data<T>(
-        place_, bwd_pd_->diff_src_primitive_desc().get_size());
-    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
-                                            ptr, "@diff_src_mem_p");
-  }
-
   std::shared_ptr<mkldnn::eltwise_forward> AcquireActivation(
       std::shared_ptr<mkldnn::memory> dst_memory_p,
       std::shared_ptr<mkldnn::memory> src_memory_p) {
     /*Generate key*/
-    auto prim_key = key_ + "@eltwise_p";
+    auto prim_key = this->key_ + "@eltwise_p";
     auto eltwise_p = std::static_pointer_cast<mkldnn::eltwise_forward>(
-        dev_ctx_.GetBlob(prim_key));
+        this->dev_ctx_.GetBlob(prim_key));
     if (eltwise_p == nullptr) {
       eltwise_p = std::make_shared<mkldnn::eltwise_forward>(
-          *fwd_pd_, *(src_memory_p), *(dst_memory_p));
-      dev_ctx_.SetBlob(prim_key, eltwise_p);
+          *this->fwd_pd_, *(src_memory_p), *(dst_memory_p));
+      this->dev_ctx_.SetBlob(prim_key, eltwise_p);
     }
 
     return eltwise_p;
@@ -364,15 +410,15 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
       std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
       std::shared_ptr<mkldnn::memory> src_memory_p) {
     /*Generate key*/
-    auto prim_key = key_ + "@eltwise_bwd_p";
+    auto prim_key = this->key_ + "@eltwise_bwd_p";
     auto eltwise_bwd_p = std::static_pointer_cast<mkldnn::eltwise_backward>(
-        dev_ctx_.GetBlob(prim_key));
+        this->dev_ctx_.GetBlob(prim_key));
     if (eltwise_bwd_p == nullptr) {
       eltwise_bwd_p = std::make_shared<mkldnn::eltwise_backward>(
-          *bwd_pd_, *(src_memory_p), *(diff_dst_memory_p),
+          *this->bwd_pd_, *(src_memory_p), *(diff_dst_memory_p),
           *(diff_src_memory_p));
-      dev_ctx_.SetBlob(prim_key, eltwise_bwd_p);
+      this->dev_ctx_.SetBlob(prim_key, eltwise_bwd_p);
     }
 
     return eltwise_bwd_p;
@@ -387,26 +433,27 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
     // Activation PD has to be passed to Grad op that
    // may be executed by a different thread, hence
     // for that one we use key that does not contain TID
-    const std::string key_activation_pd = key_common_ + "@activation_pd";
-    fwd_pd_ = std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
-        dev_ctx_.GetBlob(key_activation_pd));
-    if (fwd_pd_ == nullptr) {
+    const std::string key_activation_pd = this->key_common_ + "@activation_pd";
+    this->fwd_pd_ =
+        std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
+            this->dev_ctx_.GetBlob(key_activation_pd));
+    if (this->fwd_pd_ == nullptr) {
       static std::mutex acquire_barrier;
       std::lock_guard<std::mutex> block_threads_until_finish_this_job(
          acquire_barrier);
-      fwd_pd_ =
+      this->fwd_pd_ =
           std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
-              dev_ctx_.GetBlob(key_activation_pd));
-      if (fwd_pd_ == nullptr) {
+              this->dev_ctx_.GetBlob(key_activation_pd));
+      if (this->fwd_pd_ == nullptr) {
         auto md = platform::MKLDNNMemDesc(
             dims, platform::MKLDNNGetDataType<T>(), fmt);
         auto activation_desc = mkldnn::eltwise_forward::desc(
             prop_kind, algorithm, md, alpha, beta);
-        fwd_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
-            activation_desc, engine_));
-        dev_ctx_.SetBlob(key_activation_pd, fwd_pd_);
+        this->fwd_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
+            activation_desc, this->engine_));
+        this->dev_ctx_.SetBlob(key_activation_pd, this->fwd_pd_);
       }
     }
   }
@@ -415,17 +462,18 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
       mkldnn::algorithm algorithm, const std::vector<int>& dims,
       const MKLDNNMemoryFormat fmt, const MKLDNNMemoryFormat diff_fmt,
       float alpha, float beta) {
-    const std::string key_activation_pd = key_common_ + "@activation_pd";
-    const std::string key_activation_bwd_pd = key_ + "@activation_bwd_pd";
-    bwd_pd_ =
+    const std::string key_activation_pd = this->key_common_ + "@activation_pd";
+    const std::string key_activation_bwd_pd = this->key_ + "@activation_bwd_pd";
+    this->bwd_pd_ =
         std::static_pointer_cast<mkldnn::eltwise_backward::primitive_desc>(
-            dev_ctx_.GetBlob(key_activation_bwd_pd));
-    if (bwd_pd_ == nullptr) {
-      fwd_pd_ =
+            this->dev_ctx_.GetBlob(key_activation_bwd_pd));
+    if (this->bwd_pd_ == nullptr) {
+      this->fwd_pd_ =
           std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
-              dev_ctx_.GetBlob(key_activation_pd));
+              this->dev_ctx_.GetBlob(key_activation_pd));
       // PD from FWD op has to exist.
-      PADDLE_ENFORCE_NOT_NULL(fwd_pd_, "Eltwise MKL-DNN not found in cache!");
+      PADDLE_ENFORCE_NOT_NULL(this->fwd_pd_,
+                              "Eltwise MKL-DNN not found in cache!");
 
       auto diff_dst_md = platform::MKLDNNMemDesc(
           dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
@@ -434,16 +482,11 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
       auto backward_desc = mkldnn::eltwise_backward::desc(
           algorithm, diff_dst_md, src_md, alpha, beta);
-      bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
-          backward_desc, engine_, *fwd_pd_));
-      dev_ctx_.SetBlob(key_activation_bwd_pd, bwd_pd_);
+      this->bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
+          backward_desc, this->engine_, *this->fwd_pd_));
+      this->dev_ctx_.SetBlob(key_activation_bwd_pd, this->bwd_pd_);
     }
   }
-
- private:
-  platform::Place place_;
-  std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd_;
-  std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> bwd_pd_;
 };
 
 class LRNMKLDNNHandler : public MKLDNNHandler {
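AcquireActivationPrimitiveDescriptor above (like its softmax counterpart) caches the forward PD under a TID-free key using double-checked locking: a lock-free first lookup, then a mutex-guarded re-check before creating and publishing the descriptor. A minimal standalone sketch of the pattern, where `cache`, `Get`/`Set`, and `PrimDesc` are hypothetical stand-ins for the device context's GetBlob/SetBlob and a primitive_desc type:

    std::shared_ptr<PrimDesc> pd = cache.Get(key);  // fast path, no lock
    if (pd == nullptr) {
      static std::mutex acquire_barrier;
      std::lock_guard<std::mutex> guard(acquire_barrier);
      pd = cache.Get(key);  // re-check: another thread may have won the race
      if (pd == nullptr) {
        pd = std::make_shared<PrimDesc>(/* desc, engine */);
        cache.Set(key, pd);  // publish under the lock
      }
    }

The pattern is only safe if the cache accessors themselves are synchronized, as the blob map in MKLDNNDeviceContext is; otherwise the lock-free first read would race with the write.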