@@ -33,17 +33,18 @@ using mkldnn::stream;
 using platform::to_void_cast;
 
 template <typename T>
-class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
+class SoftmaxMKLDNNHandler
+    : public platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
+                                      mkldnn::softmax_backward> {
  public:
   SoftmaxMKLDNNHandler(const std::vector<int>& dims,
                        const MKLDNNMemoryFormat fmt,
                        const platform::MKLDNNDeviceContext& dev_ctx,
                        platform::Place cpu_place, const std::string& uniq_name)
-      : platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
-                                platform::CreateKey(dims, uniq_name)),
-        place_(cpu_place),
-        fwd_pd_(nullptr),
-        bwd_pd_(nullptr) {
+      : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
+                                 mkldnn::softmax_backward>(
+            dev_ctx, dev_ctx.GetEngine(), cpu_place,
+            platform::CreateKey(dims, uniq_name)) {
     this->AcquireSoftmaxPrimitiveDescriptor(dims, fmt);
   }
 
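Note on the recurring this-> additions throughout this patch: the base class is now a class template that depends on T, so names inherited from it (key_, key_common_, dev_ctx_, engine_, fwd_pd_, bwd_pd_) are no longer found by ordinary unqualified lookup inside the derived template and have to be reached through this->. A minimal standalone sketch of that C++ rule, with hypothetical names rather than the real MKLDNNHandlerT:

    // Minimal sketch: members of a base class that depends on a template
    // parameter are invisible to unqualified lookup in the derived template.
    // HandlerBase / DerivedHandler are hypothetical stand-ins.
    #include <string>

    template <typename T>
    struct HandlerBase {
      std::string key_;  // plays the role of MKLDNNHandlerT's cached key
    };

    template <typename T>
    struct DerivedHandler : public HandlerBase<T> {
      std::string PrimKey() const {
        // return key_ + "@softmax_p";     // error: 'key_' was not declared
        return this->key_ + "@softmax_p";  // OK: lookup deferred to instantiation
      }
    };

    int main() {
      DerivedHandler<float> handler;
      return handler.PrimKey().empty() ? 1 : 0;
    }

The same rule explains every key_ -> this->key_ style change in the hunks below.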
@@ -52,11 +53,10 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
                        const MKLDNNMemoryFormat diff_fmt,
                        const platform::MKLDNNDeviceContext& dev_ctx,
                        platform::Place cpu_place, const std::string& uniq_name)
-      : platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
-                                platform::CreateKey(dims, uniq_name)),
-        place_(cpu_place),
-        fwd_pd_(nullptr),
-        bwd_pd_(nullptr) {
+      : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
+                                 mkldnn::softmax_backward>(
+            dev_ctx, dev_ctx.GetEngine(), cpu_place,
+            platform::CreateKey(dims, uniq_name)) {
     // If we are in Grad operator then update a key with BWD suffix to
     // distinguish from FWD memory primitives
     // Key_common will allow to access FWD_PD from cache
@@ -64,58 +64,19 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     this->AcquireSoftmaxBackwardPrimitiveDescriptor(dims, fmt, diff_fmt);
   }
 
-  // TODO(jczaja): Once fwd_pd_ are moved to MKLDNNHandler then this function
-  // should be moved as well eg. SoftmaxMKLDNNHandler -> MKLDNNHandler<softmax_>
-  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(const Tensor* input) {
-    const T* input_data = input->data<T>();
-    return this->AcquireMemoryFromPrimitive(fwd_pd_->src_primitive_desc(),
-                                            to_void_cast<T>(input_data),
-                                            "@src_mem_p");
-  }
-
-  // TODO(jczaja): Move to MKLDNNHandler as common code
-  std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
-    T* ptr = output->mutable_data<T>(place_,
-                                     fwd_pd_->dst_primitive_desc().get_size());
-    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
-                                            "@dst_mem_p");
-  }
-
-  std::shared_ptr<mkldnn::memory> AcquireDstMemory(const Tensor* output) {
-    const T* output_data = output->data<T>();
-    return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_primitive_desc(),
-                                            to_void_cast<T>(output_data),
-                                            "@bwd-dst_mem_p");
-  }
-
-  std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(const Tensor* diffdst) {
-    const T* ptr = diffdst->data<T>();
-    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_primitive_desc(),
-                                            to_void_cast<T>(ptr),
-                                            "@diff_dst_mem_p");
-  }
-
-  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
-      framework::Tensor* diffsrc) {
-    T* ptr = diffsrc->mutable_data<T>(
-        place_, bwd_pd_->diff_src_primitive_desc().get_size());
-    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
-                                            ptr, "@diff_src_mem_p");
-  }
-
   std::shared_ptr<mkldnn::softmax_forward> AcquireSoftmax(
       std::shared_ptr<mkldnn::memory> dst_memory_p,
       std::shared_ptr<mkldnn::memory> src_memory_p) {
     /*Generate key*/
-    auto prim_key = key_ + "@softmax_p";
+    auto prim_key = this->key_ + "@softmax_p";
 
     auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
-        dev_ctx_.GetBlob(prim_key));
+        this->dev_ctx_.GetBlob(prim_key));
     if (softmax_p == nullptr) {
       softmax_p = std::make_shared<mkldnn::softmax_forward>(
-          *fwd_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
+          *this->fwd_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
           *(static_cast<mkldnn::memory*>(dst_memory_p.get())));
-      dev_ctx_.SetBlob(prim_key, softmax_p);
+      this->dev_ctx_.SetBlob(prim_key, softmax_p);
     }
 
     return softmax_p;
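The five Acquire*Memory helpers deleted above differed only in which primitive descriptor and cache suffix they used, and the removed TODO comments already pointed at hoisting them into the handler base once fwd_pd_/bwd_pd_ live there. A hypothetical sketch of that hoisting, using stand-in types rather than the real Paddle/MKL-DNN classes:

    // Hypothetical sketch of hoisting AcquireDstMemory-style boilerplate into a
    // T-aware base; FakeTensor and the byte-size logic are stand-ins, not the
    // actual framework::Tensor or primitive_desc API.
    #include <cstddef>
    #include <memory>
    #include <string>
    #include <vector>

    struct FakeTensor {  // stand-in for framework::Tensor
      std::vector<char> buf;
      void* mutable_data(std::size_t bytes) {
        buf.resize(bytes);
        return buf.data();
      }
    };

    template <typename T>
    struct HandlerBaseT {
      std::shared_ptr<void> AcquireDstMemory(FakeTensor* output) {
        // allocate (or reuse) the destination buffer, then register it in the
        // cache under a "@dst_mem_p" style suffix
        void* ptr = output->mutable_data(dst_size_bytes_);
        return AcquireMemoryFromPrimitive(ptr, "@dst_mem_p");
      }

     protected:
      std::shared_ptr<void> AcquireMemoryFromPrimitive(void* ptr,
                                                       const std::string&) {
        return std::shared_ptr<void>(ptr, [](void*) {});  // non-owning handle
      }
      std::size_t dst_size_bytes_ = 16 * sizeof(T);  // pretend dst size query
    };

    // The derived softmax handler no longer re-implements the helper.
    template <typename T>
    struct SoftmaxHandlerSketch : public HandlerBaseT<T> {};

    int main() {
      FakeTensor out;
      SoftmaxHandlerSketch<float> handler;
      return handler.AcquireDstMemory(&out) ? 0 : 1;
    }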
@@ -125,13 +86,14 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
       std::shared_ptr<mkldnn::memory> dst_memory_p,
       std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
       std::shared_ptr<mkldnn::memory> diff_src_memory_p) {
-    auto prim_key = key_ + "@softmax_bwd_p";
+    auto prim_key = this->key_ + "@softmax_bwd_p";
     auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
-        dev_ctx_.GetBlob(prim_key));
+        this->dev_ctx_.GetBlob(prim_key));
     if (softmax_bwd_p == nullptr) {
       softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
-          *bwd_pd_, *dst_memory_p, *diff_dst_memory_p, *diff_src_memory_p);
-      dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
+          *this->bwd_pd_, *dst_memory_p, *diff_dst_memory_p,
+          *diff_src_memory_p);
+      this->dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
     }
 
     return softmax_bwd_p;
@@ -143,17 +105,17 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     // Softmax PD has to be passed to Grad op that
     // may be executed by different thread, hence
     // for that one we use key that does not contain TID
-    const std::string key_softmax_pd = key_common_ + "@softmax_pd";
+    const std::string key_softmax_pd = this->key_common_ + "@softmax_pd";
 
-    fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
-        dev_ctx_.GetBlob(key_softmax_pd));
-    if (fwd_pd_ == nullptr) {
+    this->fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
+        this->dev_ctx_.GetBlob(key_softmax_pd));
+    if (this->fwd_pd_ == nullptr) {
       static std::mutex acquire_barrier;
       std::lock_guard<std::mutex> block_threads_until_finish_this_job(
           acquire_barrier);
-      fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
-          dev_ctx_.GetBlob(key_softmax_pd));
-      if (fwd_pd_ == nullptr) {
+      this->fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
+          this->dev_ctx_.GetBlob(key_softmax_pd));
+      if (this->fwd_pd_ == nullptr) {
         // TODO(jczaja): Make it working along chosen axis and for
         // forward_training
         // Normalization is made after innermost dimension eg. C out of NC
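The forward primitive descriptor is shared between the forward and Grad kernels through the device-context blob cache, which is why the lookup above is repeated after taking a function-local static mutex: only the first thread actually builds and publishes the descriptor. A simplified sketch of that double-checked pattern, with a hypothetical BlobCache standing in for the device-context storage (its GetBlob/SetBlob are assumed to be internally synchronized, as the handler assumes of the real storage):

    // Simplified sketch of the double-checked acquisition above; BlobCache and
    // the int payload are stand-ins for the device context and primitive_desc.
    #include <memory>
    #include <mutex>
    #include <string>
    #include <unordered_map>
    #include <utility>

    struct BlobCache {
      std::shared_ptr<void> GetBlob(const std::string& key) {
        std::lock_guard<std::mutex> guard(mutex_);
        auto it = blobs_.find(key);
        return it == blobs_.end() ? nullptr : it->second;
      }
      void SetBlob(const std::string& key, std::shared_ptr<void> blob) {
        std::lock_guard<std::mutex> guard(mutex_);
        blobs_[key] = std::move(blob);
      }

     private:
      std::mutex mutex_;
      std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
    };

    std::shared_ptr<void> AcquirePD(BlobCache* dev_ctx, const std::string& key) {
      auto pd = dev_ctx->GetBlob(key);      // fast path: already built
      if (pd == nullptr) {
        static std::mutex acquire_barrier;  // serializes the builders only
        std::lock_guard<std::mutex> block(acquire_barrier);
        pd = dev_ctx->GetBlob(key);         // re-check: another thread may have won
        if (pd == nullptr) {
          pd = std::make_shared<int>(0);    // stands in for primitive_desc creation
          dev_ctx->SetBlob(key, pd);
        }
      }
      return pd;
    }

    int main() {
      BlobCache cache;
      auto first = AcquirePD(&cache, "@softmax_pd");
      auto second = AcquirePD(&cache, "@softmax_pd");
      return first == second ? 0 : 1;  // the cached descriptor is reused
    }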
@@ -161,9 +123,9 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
             mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
         auto softmax_desc =
             softmax_forward::desc(prop_kind::forward_scoring, md, 1 /*dim: C*/);
-        fwd_pd_.reset(
-            new softmax_forward::primitive_desc(softmax_desc, engine_));
-        dev_ctx_.SetBlob(key_softmax_pd, fwd_pd_);
+        this->fwd_pd_.reset(
+            new softmax_forward::primitive_desc(softmax_desc, this->engine_));
+        this->dev_ctx_.SetBlob(key_softmax_pd, this->fwd_pd_);
       }
     }
   }
@@ -172,12 +134,12 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
       const std::vector<int>& dims, const mkldnn::memory::format fmt,
       const mkldnn::memory::format diff_fmt) {
     // Fwd_PD_ has to exist in order to create BWD_PD_
-    PADDLE_ENFORCE_NOT_NULL(fwd_pd_);
-    const std::string key_bwd_pd = key_ + "@softmax_bwd_pd";
-    bwd_pd_ =
+    PADDLE_ENFORCE_NOT_NULL(this->fwd_pd_);
+    const std::string key_bwd_pd = this->key_ + "@softmax_bwd_pd";
+    this->bwd_pd_ =
         std::static_pointer_cast<mkldnn::softmax_backward::primitive_desc>(
-            dev_ctx_.GetBlob(key_bwd_pd));
-    if (bwd_pd_ == nullptr) {
+            this->dev_ctx_.GetBlob(key_bwd_pd));
+    if (this->bwd_pd_ == nullptr) {
       auto data_softmax_md =
           mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
       auto diff_softmax_md = mkldnn::memory::desc(
@@ -185,16 +147,11 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
       // TODO(jczaja): Add support for other axes
       auto backward_desc = softmax_backward::desc(
           diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
-      bwd_pd_.reset(new mkldnn::softmax_backward::primitive_desc(
-          backward_desc, engine_, *fwd_pd_));
-      dev_ctx_.SetBlob(key_bwd_pd, bwd_pd_);
+      this->bwd_pd_.reset(new mkldnn::softmax_backward::primitive_desc(
+          backward_desc, this->engine_, *this->fwd_pd_));
+      this->dev_ctx_.SetBlob(key_bwd_pd, this->bwd_pd_);
     }
   }
-
- private:
-  platform::Place place_;
-  std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd_;
-  std::shared_ptr<mkldnn::softmax_backward::primitive_desc> bwd_pd_;
 };
 
 template <typename T>