|
|
@ -16,6 +16,7 @@ limitations under the License. */
|
|
|
|
#include <memory>
|
|
|
|
#include <memory>
|
|
|
|
#include <sstream>
|
|
|
|
#include <sstream>
|
|
|
|
#include <string>
|
|
|
|
#include <string>
|
|
|
|
|
|
|
|
#include <utility>
|
|
|
|
#include <vector>
|
|
|
|
#include <vector>
|
|
|
|
#include "boost/optional.hpp"
|
|
|
|
#include "boost/optional.hpp"
|
|
|
|
#include "paddle/fluid/framework/data_layout_transform.h"
|
|
|
|
#include "paddle/fluid/framework/data_layout_transform.h"
|
|
|
@ -48,6 +49,32 @@ class MKLDNNHandlerT {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename... Args>
|
|
|
|
|
|
|
|
std::shared_ptr<TForward> AcquireForwardPrimitive(Args&&... args) {
|
|
|
|
|
|
|
|
const std::string key_p = key_ + "@forward_p";
|
|
|
|
|
|
|
|
auto forward_p =
|
|
|
|
|
|
|
|
std::static_pointer_cast<TForward>(dev_ctx_.GetBlob(key_p));
|
|
|
|
|
|
|
|
if (forward_p == nullptr) {
|
|
|
|
|
|
|
|
forward_p =
|
|
|
|
|
|
|
|
std::make_shared<TForward>(*fwd_pd_, std::forward<Args>(args)...);
|
|
|
|
|
|
|
|
dev_ctx_.SetBlob(key_p, forward_p);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return forward_p;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename... Args>
|
|
|
|
|
|
|
|
std::shared_ptr<TBackward> AcquireBackwardPrimitive(Args&&... args) {
|
|
|
|
|
|
|
|
const std::string key_p = key_ + "@backward_p";
|
|
|
|
|
|
|
|
auto backward_p =
|
|
|
|
|
|
|
|
std::static_pointer_cast<TBackward>(dev_ctx_.GetBlob(key_p));
|
|
|
|
|
|
|
|
if (backward_p == nullptr) {
|
|
|
|
|
|
|
|
backward_p =
|
|
|
|
|
|
|
|
std::make_shared<TBackward>(*bwd_pd_, std::forward<Args>(args)...);
|
|
|
|
|
|
|
|
dev_ctx_.SetBlob(key_p, backward_p);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return backward_p;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
|
|
|
|
const framework::Tensor* input) {
|
|
|
|
const framework::Tensor* input) {
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
@ -87,6 +114,44 @@ class MKLDNNHandlerT {
|
|
|
|
ptr, "@diff_src_mem_p");
|
|
|
|
ptr, "@diff_src_mem_p");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
|
|
|
template <typename... Args>
|
|
|
|
|
|
|
|
void AcquireForwardPrimitiveDescriptor(Args&&... args) {
|
|
|
|
|
|
|
|
// Forward PD has to be passed to Grad op that
|
|
|
|
|
|
|
|
// may be executed by diffrent thread, hence
|
|
|
|
|
|
|
|
// for that one we use key that does not contain TID
|
|
|
|
|
|
|
|
const std::string key_pd = key_common_ + "@forward_pd";
|
|
|
|
|
|
|
|
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
|
|
|
|
|
|
|
|
dev_ctx_.GetBlob(key_pd));
|
|
|
|
|
|
|
|
if (fwd_pd_ == nullptr) {
|
|
|
|
|
|
|
|
static std::mutex acquire_barrier;
|
|
|
|
|
|
|
|
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
|
|
|
|
|
|
|
|
acquire_barrier);
|
|
|
|
|
|
|
|
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
|
|
|
|
|
|
|
|
dev_ctx_.GetBlob(key_pd));
|
|
|
|
|
|
|
|
if (fwd_pd_ == nullptr) {
|
|
|
|
|
|
|
|
auto fwd_desc = typename TForward::desc(std::forward<Args>(args)...);
|
|
|
|
|
|
|
|
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(fwd_desc,
|
|
|
|
|
|
|
|
engine_);
|
|
|
|
|
|
|
|
dev_ctx_.SetBlob(key_pd, fwd_pd_);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename... Args>
|
|
|
|
|
|
|
|
void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(fwd_pd_);
|
|
|
|
|
|
|
|
const std::string key_pd = key_ + "@backward_pd";
|
|
|
|
|
|
|
|
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
|
|
|
|
|
|
|
|
dev_ctx_.GetBlob(key_pd));
|
|
|
|
|
|
|
|
if (bwd_pd_ == nullptr) {
|
|
|
|
|
|
|
|
auto bwd_desc = typename TBackward::desc(std::forward<Args>(args)...);
|
|
|
|
|
|
|
|
bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
|
|
|
|
|
|
|
|
bwd_desc, engine_, *fwd_pd_);
|
|
|
|
|
|
|
|
dev_ctx_.SetBlob(key_pd, bwd_pd_);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
|
|
|
|
mkldnn::memory::primitive_desc mdp, void* ptr,
|
|
|
|
mkldnn::memory::primitive_desc mdp, void* ptr,
|
|
|
|
const std::string& suffix) {
|
|
|
|
const std::string& suffix) {
|
|
|
@ -102,7 +167,6 @@ class MKLDNNHandlerT {
|
|
|
|
return mem_p;
|
|
|
|
return mem_p;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
|
|
|
const MKLDNNDeviceContext& dev_ctx_;
|
|
|
|
const MKLDNNDeviceContext& dev_ctx_;
|
|
|
|
mkldnn::engine engine_;
|
|
|
|
mkldnn::engine engine_;
|
|
|
|
platform::Place place_;
|
|
|
|
platform::Place place_;
|
|
|
@ -355,10 +419,12 @@ class ActivationMKLDNNHandler
|
|
|
|
dev_ctx, dev_ctx.GetEngine(), cpu_place,
|
|
|
|
dev_ctx, dev_ctx.GetEngine(), cpu_place,
|
|
|
|
platform::CreateKey(dims, algorithm, fmt, alpha, beta,
|
|
|
|
platform::CreateKey(dims, algorithm, fmt, alpha, beta,
|
|
|
|
unique_name)) {
|
|
|
|
unique_name)) {
|
|
|
|
AcquireActivationPrimitiveDescriptor(
|
|
|
|
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
this->AcquireForwardPrimitiveDescriptor(
|
|
|
|
is_test ? mkldnn::prop_kind::forward_inference
|
|
|
|
is_test ? mkldnn::prop_kind::forward_inference
|
|
|
|
: mkldnn::prop_kind::forward_training,
|
|
|
|
: mkldnn::prop_kind::forward_training,
|
|
|
|
algorithm, dims, fmt, alpha, beta);
|
|
|
|
algorithm, md, alpha, beta);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ActivationMKLDNNHandler(const std::vector<int>& dims,
|
|
|
|
ActivationMKLDNNHandler(const std::vector<int>& dims,
|
|
|
@ -374,10 +440,15 @@ class ActivationMKLDNNHandler
|
|
|
|
dev_ctx, dev_ctx.GetEngine(), cpu_place,
|
|
|
|
dev_ctx, dev_ctx.GetEngine(), cpu_place,
|
|
|
|
platform::CreateKey(dims, algorithm, fmt, alpha, beta,
|
|
|
|
platform::CreateKey(dims, algorithm, fmt, alpha, beta,
|
|
|
|
unique_name)) {
|
|
|
|
unique_name)) {
|
|
|
|
AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
|
|
|
|
auto diff_dst_md = platform::MKLDNNMemDesc(
|
|
|
|
algorithm, dims, fmt, alpha, beta);
|
|
|
|
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
|
|
|
|
AcquireActivationBackwardPrimitiveDescriptor(algorithm, dims, fmt, diff_fmt,
|
|
|
|
auto src_md =
|
|
|
|
alpha, beta);
|
|
|
|
platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), fmt);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
|
|
|
|
|
|
|
|
algorithm, src_md, alpha, beta);
|
|
|
|
|
|
|
|
this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md,
|
|
|
|
|
|
|
|
alpha, beta);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory(
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory(
|
|
|
@ -387,106 +458,6 @@ class ActivationMKLDNNHandler
|
|
|
|
to_void_cast<T>(input_data),
|
|
|
|
to_void_cast<T>(input_data),
|
|
|
|
"@bwd-src_mem_p");
|
|
|
|
"@bwd-src_mem_p");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::eltwise_forward> AcquireActivation(
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> dst_memory_p,
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> src_memory_p) {
|
|
|
|
|
|
|
|
/*Generate key*/
|
|
|
|
|
|
|
|
auto prim_key = this->key_ + "@eltwise_p";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto eltwise_p = std::static_pointer_cast<mkldnn::eltwise_forward>(
|
|
|
|
|
|
|
|
this->dev_ctx_.GetBlob(prim_key));
|
|
|
|
|
|
|
|
if (eltwise_p == nullptr) {
|
|
|
|
|
|
|
|
eltwise_p = std::make_shared<mkldnn::eltwise_forward>(
|
|
|
|
|
|
|
|
*this->fwd_pd_, *(src_memory_p), *(dst_memory_p));
|
|
|
|
|
|
|
|
this->dev_ctx_.SetBlob(prim_key, eltwise_p);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return eltwise_p;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::eltwise_backward> AcquireActivationBackward(
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> diff_src_memory_p,
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> src_memory_p) {
|
|
|
|
|
|
|
|
/*Generate key*/
|
|
|
|
|
|
|
|
auto prim_key = this->key_ + "@eltwise_bwd_p";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto eltwise_bwd_p = std::static_pointer_cast<mkldnn::eltwise_backward>(
|
|
|
|
|
|
|
|
this->dev_ctx_.GetBlob(prim_key));
|
|
|
|
|
|
|
|
if (eltwise_bwd_p == nullptr) {
|
|
|
|
|
|
|
|
eltwise_bwd_p = std::make_shared<mkldnn::eltwise_backward>(
|
|
|
|
|
|
|
|
*this->bwd_pd_, *(src_memory_p), *(diff_dst_memory_p),
|
|
|
|
|
|
|
|
*(diff_src_memory_p));
|
|
|
|
|
|
|
|
this->dev_ctx_.SetBlob(prim_key, eltwise_bwd_p);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return eltwise_bwd_p;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
|
|
|
void AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind prop_kind,
|
|
|
|
|
|
|
|
mkldnn::algorithm algorithm,
|
|
|
|
|
|
|
|
const std::vector<int>& dims,
|
|
|
|
|
|
|
|
const MKLDNNMemoryFormat fmt,
|
|
|
|
|
|
|
|
float alpha, float beta) {
|
|
|
|
|
|
|
|
// Activation PD has to be passed to Grad op that
|
|
|
|
|
|
|
|
// may be executed by diffrent thread, hence
|
|
|
|
|
|
|
|
// for that one we use key that does not contain TID
|
|
|
|
|
|
|
|
const std::string key_activation_pd = this->key_common_ + "@activation_pd";
|
|
|
|
|
|
|
|
this->fwd_pd_ =
|
|
|
|
|
|
|
|
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
|
|
|
|
|
|
|
|
this->dev_ctx_.GetBlob(key_activation_pd));
|
|
|
|
|
|
|
|
if (this->fwd_pd_ == nullptr) {
|
|
|
|
|
|
|
|
static std::mutex acquire_barrier;
|
|
|
|
|
|
|
|
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
|
|
|
|
|
|
|
|
acquire_barrier);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
this->fwd_pd_ =
|
|
|
|
|
|
|
|
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
|
|
|
|
|
|
|
|
this->dev_ctx_.GetBlob(key_activation_pd));
|
|
|
|
|
|
|
|
if (this->fwd_pd_ == nullptr) {
|
|
|
|
|
|
|
|
auto md = platform::MKLDNNMemDesc(
|
|
|
|
|
|
|
|
dims, platform::MKLDNNGetDataType<T>(), fmt);
|
|
|
|
|
|
|
|
auto activation_desc = mkldnn::eltwise_forward::desc(
|
|
|
|
|
|
|
|
prop_kind, algorithm, md, alpha, beta);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
this->fwd_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
|
|
|
|
|
|
|
|
activation_desc, this->engine_));
|
|
|
|
|
|
|
|
this->dev_ctx_.SetBlob(key_activation_pd, this->fwd_pd_);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void AcquireActivationBackwardPrimitiveDescriptor(
|
|
|
|
|
|
|
|
mkldnn::algorithm algorithm, const std::vector<int>& dims,
|
|
|
|
|
|
|
|
const MKLDNNMemoryFormat fmt, const MKLDNNMemoryFormat diff_fmt,
|
|
|
|
|
|
|
|
float alpha, float beta) {
|
|
|
|
|
|
|
|
const std::string key_activation_pd = this->key_common_ + "@activation_pd";
|
|
|
|
|
|
|
|
const std::string key_activation_bwd_pd = this->key_ + "@activation_bwd_pd";
|
|
|
|
|
|
|
|
this->bwd_pd_ =
|
|
|
|
|
|
|
|
std::static_pointer_cast<mkldnn::eltwise_backward::primitive_desc>(
|
|
|
|
|
|
|
|
this->dev_ctx_.GetBlob(key_activation_bwd_pd));
|
|
|
|
|
|
|
|
if (this->bwd_pd_ == nullptr) {
|
|
|
|
|
|
|
|
this->fwd_pd_ =
|
|
|
|
|
|
|
|
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
|
|
|
|
|
|
|
|
this->dev_ctx_.GetBlob(key_activation_pd));
|
|
|
|
|
|
|
|
// PD from FWD op has to exist.
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(this->fwd_pd_,
|
|
|
|
|
|
|
|
"Eltwise MKL-DNN not found in cache!");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto diff_dst_md = platform::MKLDNNMemDesc(
|
|
|
|
|
|
|
|
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
|
|
|
|
|
|
|
|
auto src_md =
|
|
|
|
|
|
|
|
platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), fmt);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto backward_desc = mkldnn::eltwise_backward::desc(
|
|
|
|
|
|
|
|
algorithm, diff_dst_md, src_md, alpha, beta);
|
|
|
|
|
|
|
|
this->bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
|
|
|
|
|
|
|
|
backward_desc, this->engine_, *this->fwd_pd_));
|
|
|
|
|
|
|
|
this->dev_ctx_.SetBlob(key_activation_bwd_pd, this->bwd_pd_);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class LRNMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
class LRNMKLDNNHandler : public MKLDNNHandler {
|
|
|
|