|
|
|
@ -54,18 +54,24 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
|
|
|
|
|
std::shared_ptr<softmax_forward::primitive_desc>
|
|
|
|
|
AcquireSoftmaxPrimitiveDescriptor(const softmax_forward::desc& softmax_desc,
|
|
|
|
|
const mkldnn::engine& engine) {
|
|
|
|
|
const std::string key_softmax_pd = key_ + "@softmax_pd";
|
|
|
|
|
// Softmax PD has to be passed to Grad op that
|
|
|
|
|
// may be executed by diffrent thread, hence
|
|
|
|
|
// for that one we use key that does not contain TID
|
|
|
|
|
const std::string key_softmax_pd = key_common_ + "@softmax_pd";
|
|
|
|
|
|
|
|
|
|
auto softmax_pd = std::static_pointer_cast<softmax_forward::primitive_desc>(
|
|
|
|
|
softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
|
|
|
|
|
dev_ctx_.GetBlob(key_softmax_pd));
|
|
|
|
|
|
|
|
|
|
if (softmax_pd == nullptr) {
|
|
|
|
|
softmax_pd_.reset(
|
|
|
|
|
new softmax_forward::primitive_desc(softmax_desc, engine));
|
|
|
|
|
dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_);
|
|
|
|
|
} else {
|
|
|
|
|
softmax_pd_ = softmax_pd;
|
|
|
|
|
is_reusing_ = true;
|
|
|
|
|
if (softmax_pd_ == nullptr) {
|
|
|
|
|
static std::mutex acquire_barrier;
|
|
|
|
|
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
|
|
|
|
|
acquire_barrier);
|
|
|
|
|
softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
|
|
|
|
|
dev_ctx_.GetBlob(key_softmax_pd));
|
|
|
|
|
if (softmax_pd_ == nullptr) {
|
|
|
|
|
softmax_pd_.reset(
|
|
|
|
|
new softmax_forward::primitive_desc(softmax_desc, engine));
|
|
|
|
|
dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return softmax_pd_;
|
|
|
|
@ -79,15 +85,11 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
|
|
|
|
|
|
|
|
|
|
auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
|
|
|
|
|
dev_ctx_.GetBlob(prim_key));
|
|
|
|
|
PADDLE_ENFORCE((softmax_p != nullptr) || (is_reusing_ == false),
|
|
|
|
|
"Fail to find softmax primitive in device context");
|
|
|
|
|
if (softmax_p == nullptr) {
|
|
|
|
|
softmax_p = std::make_shared<mkldnn::softmax_forward>(
|
|
|
|
|
*softmax_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
|
|
|
|
|
*(static_cast<mkldnn::memory*>(dst_memory_p.get())));
|
|
|
|
|
dev_ctx_.SetBlob(prim_key, softmax_p);
|
|
|
|
|
} else {
|
|
|
|
|
is_reusing_ = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return softmax_p;
|
|
|
|
@ -100,15 +102,11 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
|
|
|
|
|
auto prim_key = key_ + "@softmax_bwd_p";
|
|
|
|
|
auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
|
|
|
|
|
dev_ctx_.GetBlob(prim_key));
|
|
|
|
|
PADDLE_ENFORCE((softmax_bwd_p != nullptr) || (is_reusing_ == false),
|
|
|
|
|
"Fail to find softmax backward primitive in device context");
|
|
|
|
|
if (softmax_bwd_p == nullptr) {
|
|
|
|
|
softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
|
|
|
|
|
*softmax_bwd_pd_, *dst_memory_p, *diff_dst_memory_p,
|
|
|
|
|
*diff_src_memory_p);
|
|
|
|
|
dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
|
|
|
|
|
} else {
|
|
|
|
|
is_reusing_ = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return softmax_bwd_p;
|
|
|
|
|