|
|
|
@ -460,141 +460,64 @@ class ActivationMKLDNNHandler
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class LRNMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
|
template <typename T>
|
|
|
|
|
class LRNMKLDNNHandler
|
|
|
|
|
: public MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward> {
|
|
|
|
|
public:
|
|
|
|
|
// Constructs the handler: forwards dev_ctx, engine and base_key to the
// platform::MKLDNNHandler base and records is_test in is_test_.
LRNMKLDNNHandler(bool is_test,
                 const platform::MKLDNNDeviceContext& dev_ctx,
                 mkldnn::engine engine,
                 const std::string& base_key)
    : platform::MKLDNNHandler(dev_ctx, engine, base_key),
      is_test_(is_test) {}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::lrn_forward::primitive_desc>
|
|
|
|
|
AcquireLRNPrimitiveDescriptor(const mkldnn::memory::desc& src_md, const int n,
|
|
|
|
|
const float alpha, const float beta,
|
|
|
|
|
const float k) {
|
|
|
|
|
// LRN PD has to be passed to Grad op that
|
|
|
|
|
// may be executed by diffrent thread, hence
|
|
|
|
|
// for that one we use key that does not contain TID
|
|
|
|
|
const std::string key_lrn_pd = key_common_ + "@lrn_pd";
|
|
|
|
|
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
|
|
|
|
|
dev_ctx_.GetBlob(key_lrn_pd));
|
|
|
|
|
if (fwd_pd_ == nullptr) {
|
|
|
|
|
static std::mutex acquire_barrier;
|
|
|
|
|
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
|
|
|
|
|
acquire_barrier);
|
|
|
|
|
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
|
|
|
|
|
dev_ctx_.GetBlob(key_lrn_pd));
|
|
|
|
|
if (fwd_pd_ == nullptr) {
|
|
|
|
|
auto forward_desc = mkldnn::lrn_forward::desc{
|
|
|
|
|
is_test_ ? mkldnn::prop_kind::forward_inference
|
|
|
|
|
: mkldnn::prop_kind::forward_training,
|
|
|
|
|
mkldnn::lrn_across_channels, src_md, n, alpha, beta, k};
|
|
|
|
|
fwd_pd_.reset(
|
|
|
|
|
new mkldnn::lrn_forward::primitive_desc(forward_desc, engine_));
|
|
|
|
|
dev_ctx_.SetBlob(key_lrn_pd, fwd_pd_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return fwd_pd_;
|
|
|
|
|
}
|
|
|
|
|
LRNMKLDNNHandler(const std::vector<int>& dims, const int n, const float alpha,
|
|
|
|
|
const float beta, const float k,
|
|
|
|
|
const MKLDNNMemoryFormat fmt, bool is_test,
|
|
|
|
|
const platform::MKLDNNDeviceContext& dev_ctx,
|
|
|
|
|
platform::Place cpu_place, const std::string& unique_name)
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(void) {
|
|
|
|
|
// workspace has to be passed to Grad op that
|
|
|
|
|
// may be executed by different thread, hence
|
|
|
|
|
// for that one we use key that does not contain TID
|
|
|
|
|
auto local_key = key_common_ + "@workspace";
|
|
|
|
|
auto mem_p =
|
|
|
|
|
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
|
|
|
|
|
if (mem_p == nullptr) {
|
|
|
|
|
static std::mutex acquire_barrier;
|
|
|
|
|
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
|
|
|
|
|
acquire_barrier);
|
|
|
|
|
mem_p =
|
|
|
|
|
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
|
|
|
|
|
if (mem_p == nullptr) {
|
|
|
|
|
const std::string key_lrn_pd = key_common_ + "@lrn_pd";
|
|
|
|
|
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
|
|
|
|
|
dev_ctx_.GetBlob(key_lrn_pd));
|
|
|
|
|
// PD from FWD op has to exist.
|
|
|
|
|
PADDLE_ENFORCE(fwd_pd_ != nullptr,
|
|
|
|
|
"LRN PD MKL-DNN not found in cache!");
|
|
|
|
|
mkldnn::memory::primitive_desc workspace_mpd =
|
|
|
|
|
fwd_pd_->workspace_primitive_desc();
|
|
|
|
|
mem_p = std::make_shared<mkldnn::memory>(workspace_mpd);
|
|
|
|
|
dev_ctx_.SetBlob(local_key, mem_p);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return mem_p;
|
|
|
|
|
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
|
|
|
|
|
dev_ctx, dev_ctx.GetEngine(), cpu_place,
|
|
|
|
|
platform::CreateKey(dims, n, alpha, beta, k, fmt, unique_name)) {
|
|
|
|
|
auto src_md =
|
|
|
|
|
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
|
|
|
|
|
this->AcquireForwardPrimitiveDescriptor(
|
|
|
|
|
is_test ? mkldnn::prop_kind::forward_inference
|
|
|
|
|
: mkldnn::prop_kind::forward_training,
|
|
|
|
|
mkldnn::lrn_across_channels, src_md, n, alpha, beta, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::lrn_forward> AcquireLRN(
|
|
|
|
|
std::shared_ptr<mkldnn::memory> dst_memory,
|
|
|
|
|
std::shared_ptr<mkldnn::memory> src_memory) {
|
|
|
|
|
auto prim_key = key_ + "@lrn_p";
|
|
|
|
|
|
|
|
|
|
auto lrn_p = std::static_pointer_cast<mkldnn::lrn_forward>(
|
|
|
|
|
dev_ctx_.GetBlob(prim_key));
|
|
|
|
|
if (lrn_p == nullptr) {
|
|
|
|
|
if (is_test_) {
|
|
|
|
|
lrn_p = std::make_shared<mkldnn::lrn_forward>(*fwd_pd_, *(src_memory),
|
|
|
|
|
*(dst_memory));
|
|
|
|
|
} else {
|
|
|
|
|
// For training we need to create workspace
|
|
|
|
|
// to store indices from backward
|
|
|
|
|
auto workspace_memory = this->AcquireWorkspaceMemory();
|
|
|
|
|
LRNMKLDNNHandler(const std::vector<int>& dims, const int n, const float alpha,
|
|
|
|
|
const float beta, const float k,
|
|
|
|
|
const MKLDNNMemoryFormat fmt,
|
|
|
|
|
const MKLDNNMemoryFormat diff_fmt,
|
|
|
|
|
const platform::MKLDNNDeviceContext& dev_ctx,
|
|
|
|
|
platform::Place cpu_place, const std::string& unique_name)
|
|
|
|
|
|
|
|
|
|
lrn_p = std::make_shared<mkldnn::lrn_forward>(
|
|
|
|
|
*fwd_pd_, *src_memory, *workspace_memory, *dst_memory);
|
|
|
|
|
}
|
|
|
|
|
dev_ctx_.SetBlob(prim_key, lrn_p);
|
|
|
|
|
}
|
|
|
|
|
return lrn_p;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::lrn_backward::primitive_desc>
|
|
|
|
|
AcquireLRNBackwardPrimitiveDescriptor(const mkldnn::memory::desc& src_md,
|
|
|
|
|
const mkldnn::memory::desc& diff_md,
|
|
|
|
|
const int n, const float alpha,
|
|
|
|
|
const float beta, const float k) {
|
|
|
|
|
const std::string key_lrn_pd = key_common_ + "@lrn_pd";
|
|
|
|
|
const std::string key_lrn_bwd_pd = key_ + "@lrn_bwd_pd";
|
|
|
|
|
bwd_pd_ = std::static_pointer_cast<mkldnn::lrn_backward::primitive_desc>(
|
|
|
|
|
dev_ctx_.GetBlob(key_lrn_bwd_pd));
|
|
|
|
|
if (bwd_pd_ == nullptr) {
|
|
|
|
|
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
|
|
|
|
|
dev_ctx_.GetBlob(key_lrn_pd));
|
|
|
|
|
// PD from FWD op has to exist.
|
|
|
|
|
PADDLE_ENFORCE(fwd_pd_ != nullptr, "LRN MKL-DNN not found in cache!");
|
|
|
|
|
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
|
|
|
|
|
dev_ctx, dev_ctx.GetEngine(), cpu_place,
|
|
|
|
|
platform::CreateKey(dims, n, alpha, beta, k, fmt, unique_name)) {
|
|
|
|
|
auto src_md =
|
|
|
|
|
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
|
|
|
|
|
auto diff_md =
|
|
|
|
|
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
|
|
|
|
|
|
|
|
|
|
auto backward_desc = mkldnn::lrn_backward::desc{
|
|
|
|
|
mkldnn::lrn_across_channels, src_md, diff_md, n, alpha, beta, k};
|
|
|
|
|
bwd_pd_.reset(new mkldnn::lrn_backward::primitive_desc(
|
|
|
|
|
backward_desc, engine_, *fwd_pd_));
|
|
|
|
|
dev_ctx_.SetBlob(key_lrn_bwd_pd, bwd_pd_);
|
|
|
|
|
}
|
|
|
|
|
return bwd_pd_;
|
|
|
|
|
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
|
|
|
|
|
mkldnn::lrn_across_channels, src_md,
|
|
|
|
|
n, alpha, beta, k);
|
|
|
|
|
this->AcquireBackwardPrimitiveDescriptor(
|
|
|
|
|
mkldnn::lrn_across_channels, src_md, diff_md, n, alpha, beta, k);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<mkldnn::lrn_backward> AcquireLRNBackward(
|
|
|
|
|
std::shared_ptr<mkldnn::memory> src_memory,
|
|
|
|
|
std::shared_ptr<mkldnn::memory> diff_dst_memory,
|
|
|
|
|
std::shared_ptr<mkldnn::memory> workspace,
|
|
|
|
|
std::shared_ptr<mkldnn::memory> diff_src_memory) {
|
|
|
|
|
auto prim_key = key_ + "@lrn_bwd_p";
|
|
|
|
|
|
|
|
|
|
auto lrn_bwd_p = std::static_pointer_cast<mkldnn::lrn_backward>(
|
|
|
|
|
dev_ctx_.GetBlob(prim_key));
|
|
|
|
|
if (lrn_bwd_p == nullptr) {
|
|
|
|
|
lrn_bwd_p = std::make_shared<mkldnn::lrn_backward>(
|
|
|
|
|
*bwd_pd_, *src_memory, *diff_dst_memory, *workspace,
|
|
|
|
|
*diff_src_memory);
|
|
|
|
|
dev_ctx_.SetBlob(prim_key, lrn_bwd_p);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return lrn_bwd_p;
|
|
|
|
|
// Allocates storage for the LRN workspace inside `workspace` and wraps that
// storage in an MKL-DNN memory object described by the forward primitive
// descriptor's workspace descriptor.
//
// workspace - tensor whose buffer is (re)allocated on this->place_ and then
//             used as the backing store of the returned memory object.
//
// Returns the memory object obtained from AcquireMemoryFromPrimitive with the
// "@wrk_mem_p" suffix.
std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(
    framework::Tensor* workspace) {
  // The buffer must be sized from the same descriptor that is used to wrap
  // it. Sizing it from the destination descriptor only works while the two
  // happen to have the same size.
  const mkldnn::memory::primitive_desc workspace_mpd =
      this->fwd_pd_->workspace_primitive_desc();
  T* ptr = workspace->mutable_data<T>(this->place_, workspace_mpd.get_size());
  return this->AcquireMemoryFromPrimitive(workspace_mpd, ptr, "@wrk_mem_p");
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
bool is_test_;
|
|
|
|
|
std::shared_ptr<mkldnn::lrn_forward::primitive_desc> fwd_pd_;
|
|
|
|
|
std::shared_ptr<mkldnn::lrn_backward::primitive_desc> bwd_pd_;
|
|
|
|
|
std::shared_ptr<mkldnn::memory> AcquireBackwardWorkspaceMemory(
|
|
|
|
|
const framework::Tensor* workspace) {
|
|
|
|
|
const T* workspace_data = workspace->data<T>();
|
|
|
|
|
return this->AcquireMemoryFromPrimitive(
|
|
|
|
|
this->fwd_pd_->workspace_primitive_desc(),
|
|
|
|
|
to_void_cast<T>(workspace_data), "@bwd-wrk_mem_p");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class PoolingMKLDNNHandler : public MKLDNNHandler {
|
|
|
|
|