|
|
|
@ -257,65 +257,94 @@ class SumMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
|
std::shared_ptr<mkldnn::sum::primitive_desc> sum_pd_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class ActivationMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
|
public:
|
|
|
|
|
// Constructs a handler from a caller-supplied engine and base key. All three
// arguments are forwarded unchanged to the MKLDNNHandler base class; no
// member of this class appears in the initializer list.
ActivationMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx,
                        mkldnn::engine engine, const std::string& base_key)
    : MKLDNNHandler(dev_ctx, engine, base_key) {}
|
|
|
|
|
ActivationMKLDNNHandler(const std::vector<int>& dims,
|
|
|
|
|
mkldnn::algorithm algorithm, float alpha, float beta,
|
|
|
|
|
const MKLDNNMemoryFormat fmt, bool is_test,
|
|
|
|
|
const platform::MKLDNNDeviceContext& dev_ctx,
|
|
|
|
|
platform::Place cpu_place,
|
|
|
|
|
const std::string& unique_name)
|
|
|
|
|
|
|
|
|
|
: platform::MKLDNNHandler(
|
|
|
|
|
dev_ctx, dev_ctx.GetEngine(),
|
|
|
|
|
platform::ActivationMKLDNNHandler<T>::GetHash(
|
|
|
|
|
dims, algorithm, fmt, alpha, beta, unique_name)),
|
|
|
|
|
place_(cpu_place),
|
|
|
|
|
fwd_pd_(nullptr),
|
|
|
|
|
bwd_pd_(nullptr) {
|
|
|
|
|
AcquireActivationPrimitiveDescriptor(
|
|
|
|
|
is_test ? mkldnn::prop_kind::forward_inference
|
|
|
|
|
: mkldnn::prop_kind::forward_training,
|
|
|
|
|
algorithm, dims, fmt, alpha, beta);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ActivationMKLDNNHandler(const std::vector<int>& dims,
|
|
|
|
|
mkldnn::algorithm algorithm, float alpha, float beta,
|
|
|
|
|
const MKLDNNMemoryFormat fmt,
|
|
|
|
|
const MKLDNNMemoryFormat diff_fmt,
|
|
|
|
|
const platform::MKLDNNDeviceContext& dev_ctx,
|
|
|
|
|
platform::Place cpu_place,
|
|
|
|
|
const std::string& unique_name)
|
|
|
|
|
|
|
|
|
|
: platform::MKLDNNHandler(
|
|
|
|
|
dev_ctx, dev_ctx.GetEngine(),
|
|
|
|
|
platform::ActivationMKLDNNHandler<T>::GetHash(
|
|
|
|
|
dims, algorithm, fmt, alpha, beta, unique_name)),
|
|
|
|
|
place_(cpu_place),
|
|
|
|
|
fwd_pd_(nullptr),
|
|
|
|
|
bwd_pd_(nullptr) {
|
|
|
|
|
AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
|
|
|
|
|
algorithm, dims, fmt, alpha, beta);
|
|
|
|
|
AcquireActivationBackwardPrimitiveDescriptor(algorithm, dims, fmt, diff_fmt,
|
|
|
|
|
alpha, beta);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(jczaja): Once fwd_pd_ are moved to MKLDNNHandler then this
|
|
|
|
|
// function
|
|
|
|
|
// should be moved as well eg. ActivationMKLDNNHandler ->
|
|
|
|
|
// MKLDNNHandler<activation_>
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
|
|
|
|
|
const framework::Tensor* input) {
|
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(fwd_pd_->src_primitive_desc(),
|
|
|
|
|
to_void_cast<T>(input_data),
|
|
|
|
|
"@src_mem_p");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc>
|
|
|
|
|
AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind prop_kind,
|
|
|
|
|
mkldnn::algorithm algorithm,
|
|
|
|
|
const mkldnn::memory::desc& md,
|
|
|
|
|
float alpha, float beta) {
|
|
|
|
|
// Activation PD has to be passed to Grad op that
|
|
|
|
|
// may be executed by a different thread, hence
|
|
|
|
|
// for that one we use key that does not contain TID
|
|
|
|
|
const std::string key_activation_pd = key_common_ + "@activation_pd";
|
|
|
|
|
fwd_pd_ = std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
|
|
|
|
|
dev_ctx_.GetBlob(key_activation_pd));
|
|
|
|
|
if (fwd_pd_ == nullptr) {
|
|
|
|
|
static std::mutex acquire_barrier;
|
|
|
|
|
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
|
|
|
|
|
acquire_barrier);
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory(
|
|
|
|
|
const framework::Tensor* input) {
|
|
|
|
|
const T* input_data = input->data<T>();
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(bwd_pd_->src_primitive_desc(),
|
|
|
|
|
to_void_cast<T>(input_data),
|
|
|
|
|
"@bwd-src_mem_p");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fwd_pd_ =
|
|
|
|
|
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
|
|
|
|
|
dev_ctx_.GetBlob(key_activation_pd));
|
|
|
|
|
if (fwd_pd_ == nullptr) {
|
|
|
|
|
auto activation_desc = mkldnn::eltwise_forward::desc(
|
|
|
|
|
prop_kind, algorithm, md, alpha, beta);
|
|
|
|
|
// TODO(jczaja): Move to MKLDNNHandler as common code
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
|
|
|
|
|
T* ptr = output->mutable_data<T>(place_,
|
|
|
|
|
fwd_pd_->dst_primitive_desc().get_size());
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
|
|
|
|
|
"@dst_mem_p");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fwd_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
|
|
|
|
|
activation_desc, engine_));
|
|
|
|
|
dev_ctx_.SetBlob(key_activation_pd, fwd_pd_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return fwd_pd_;
|
|
|
|
|
// TODO(jczaja): Move to MKLDNNHandler as common code
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
|
|
|
|
|
const framework::Tensor* diffdst) {
|
|
|
|
|
const T* ptr = diffdst->data<T>();
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_primitive_desc(),
|
|
|
|
|
to_void_cast<T>(ptr),
|
|
|
|
|
"@diff_dst_mem_p");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::eltwise_backward::primitive_desc>
|
|
|
|
|
AcquireActivationBackwardPrimitiveDescriptor(
|
|
|
|
|
mkldnn::algorithm algorithm, const mkldnn::memory::desc& diff_dst_md,
|
|
|
|
|
const mkldnn::memory::desc& src_md, float alpha, float beta) {
|
|
|
|
|
const std::string key_activation_pd = key_common_ + "@activation_pd";
|
|
|
|
|
const std::string key_activation_bwd_pd = key_ + "@activation_bwd_pd";
|
|
|
|
|
bwd_pd_ =
|
|
|
|
|
std::static_pointer_cast<mkldnn::eltwise_backward::primitive_desc>(
|
|
|
|
|
dev_ctx_.GetBlob(key_activation_bwd_pd));
|
|
|
|
|
if (bwd_pd_ == nullptr) {
|
|
|
|
|
fwd_pd_ =
|
|
|
|
|
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
|
|
|
|
|
dev_ctx_.GetBlob(key_activation_pd));
|
|
|
|
|
// PD from FWD op has to exist.
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(fwd_pd_, "Eltwise MKL-DNN not found in cache!");
|
|
|
|
|
auto backward_desc = mkldnn::eltwise_backward::desc(
|
|
|
|
|
algorithm, diff_dst_md, src_md, alpha, beta);
|
|
|
|
|
bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
|
|
|
|
|
backward_desc, engine_, *fwd_pd_));
|
|
|
|
|
dev_ctx_.SetBlob(key_activation_bwd_pd, bwd_pd_);
|
|
|
|
|
}
|
|
|
|
|
return bwd_pd_;
|
|
|
|
|
// TODO(jczaja): Move to MKLDNNHandler as common code
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
|
|
|
|
|
framework::Tensor* diffsrc) {
|
|
|
|
|
T* ptr = diffsrc->mutable_data<T>(
|
|
|
|
|
place_, bwd_pd_->diff_src_primitive_desc().get_size());
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
|
|
|
|
|
ptr, "@diff_src_mem_p");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::eltwise_forward> AcquireActivation(
|
|
|
|
@ -335,20 +364,6 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
|
return eltwise_p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(
|
|
|
|
|
framework::Tensor* output, platform::Place place) {
|
|
|
|
|
T* ptr = output->mutable_data<T>(place,
|
|
|
|
|
fwd_pd_->dst_primitive_desc().get_size());
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
|
|
|
|
|
"@dst_mem_p");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromPrimitive(void* ptr) {
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
|
|
|
|
|
ptr, "@diff_src_mem_p");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::eltwise_backward> AcquireActivationBackward(
|
|
|
|
|
std::shared_ptr<mkldnn::memory> diff_src_memory_p,
|
|
|
|
|
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
|
|
|
|
@ -383,7 +398,70 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
|
return key;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
// Populates fwd_pd_ with the eltwise forward primitive descriptor.
//
// The descriptor is looked up in the device context under
// key_common_ + "@activation_pd". The Grad op may run on a different thread
// and needs this same descriptor, so the key deliberately carries no thread
// id. When the lookup misses, a function-local static mutex is taken and the
// lookup is repeated; only if it is still missing is a descriptor built from
// (dims, data type of T, fmt) and stored back under the same key.
void AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind prop_kind,
                                          mkldnn::algorithm algorithm,
                                          const std::vector<int>& dims,
                                          const MKLDNNMemoryFormat fmt,
                                          float alpha, float beta) {
  using ForwardPd = mkldnn::eltwise_forward::primitive_desc;

  const std::string pd_key = key_common_ + "@activation_pd";
  fwd_pd_ = std::static_pointer_cast<ForwardPd>(dev_ctx_.GetBlob(pd_key));
  if (fwd_pd_ != nullptr) {
    return;  // Already cached; nothing to build.
  }

  static std::mutex acquire_barrier;
  std::lock_guard<std::mutex> creation_guard(acquire_barrier);

  // Another thread may have stored the descriptor while we waited.
  fwd_pd_ = std::static_pointer_cast<ForwardPd>(dev_ctx_.GetBlob(pd_key));
  if (fwd_pd_ != nullptr) {
    return;
  }

  auto src_md =
      platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), fmt);
  mkldnn::eltwise_forward::desc forward_desc(prop_kind, algorithm, src_md,
                                             alpha, beta);
  fwd_pd_.reset(new ForwardPd(forward_desc, engine_));
  dev_ctx_.SetBlob(pd_key, fwd_pd_);
}
|
|
|
|
|
|
|
|
|
|
// Populates bwd_pd_ with the eltwise backward primitive descriptor.
//
// The descriptor is looked up in the device context under
// key_ + "@activation_bwd_pd". On a miss, the forward descriptor is fetched
// from key_common_ + "@activation_pd" (it must already exist, otherwise the
// enforce below fails), a backward descriptor is built from the diff-dst
// layout (`diff_fmt`) and the source layout (`fmt`), and the result is stored
// back under the backward key.
void AcquireActivationBackwardPrimitiveDescriptor(
    mkldnn::algorithm algorithm, const std::vector<int>& dims,
    const MKLDNNMemoryFormat fmt, const MKLDNNMemoryFormat diff_fmt,
    float alpha, float beta) {
  using ForwardPd = mkldnn::eltwise_forward::primitive_desc;
  using BackwardPd = mkldnn::eltwise_backward::primitive_desc;

  const std::string bwd_key = key_ + "@activation_bwd_pd";
  bwd_pd_ = std::static_pointer_cast<BackwardPd>(dev_ctx_.GetBlob(bwd_key));
  if (bwd_pd_ != nullptr) {
    return;  // Already cached; nothing to build.
  }

  const std::string fwd_key = key_common_ + "@activation_pd";
  fwd_pd_ = std::static_pointer_cast<ForwardPd>(dev_ctx_.GetBlob(fwd_key));
  // PD from FWD op has to exist.
  PADDLE_ENFORCE_NOT_NULL(fwd_pd_, "Eltwise MKL-DNN not found in cache!");

  const auto data_type = platform::MKLDNNGetDataType<T>();
  auto grad_out_md = platform::MKLDNNMemDesc(dims, data_type, diff_fmt);
  auto input_md = platform::MKLDNNMemDesc(dims, data_type, fmt);

  mkldnn::eltwise_backward::desc bwd_desc(algorithm, grad_out_md, input_md,
                                          alpha, beta);
  bwd_pd_.reset(new BackwardPd(bwd_desc, engine_, *fwd_pd_));
  dev_ctx_.SetBlob(bwd_key, bwd_pd_);
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd_;
|
|
|
|
|
std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> bwd_pd_;
|
|
|
|
|
};
|
|
|
|
|