[MKL-DNN] Thread-Safety for MKL-DNN reusing Part 1 (#17965)

* - Removed is_reusing_

* - Added TID to keys used for reuse, apart from the softmax PD

* - Compilation fix

* - Yet another compilation fix

* - Batch Norm and Conv adapted

* - Fix to softmax multi-threading (MT)

* - Fixes to the MT code of MKL-DNN

* - Lint fixes

test=develop
Jacek Czaja authored 6 years ago, committed by Tao Luo
parent da9143c1cc
commit 84bb45c054
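For context on the keying scheme the diff relies on: most cached objects are stored under a thread-specific key (key_, which embeds the thread id, TID), while objects that the Grad op may read from another thread are stored under a shared key without the TID (key_common_). Below is a minimal standalone sketch of deriving such a per-thread key; AppendThreadId and the example key strings are illustrative assumptions, not the handler's real member code.

#include <sstream>
#include <string>
#include <thread>

// Hypothetical helper: derive a thread-local cache key from a shared one.
// key_common_ (no TID) stays valid across threads, e.g. for the Grad op;
// key_ (with TID) keeps each thread's primitives separate.
static std::string AppendThreadId(const std::string& key_common) {
  std::ostringstream out;
  out << key_common << "-t:" << std::this_thread::get_id();
  return out.str();
}

// Usage sketch (names illustrative):
//   const std::string key_common_ = "batch_norm@dims";
//   const std::string key_ = AppendThreadId(key_common_);
//   // key_        -> per-thread primitives   ("@bn_p")
//   // key_common_ -> shared primitive descriptor ("@bn_fwd_pd")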

@@ -61,20 +61,25 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
   std::shared_ptr<batch_norm_fwd::primitive_desc>
   AcquireBatchNormPrimitiveDescriptor(const batch_norm_fwd::desc &bn_fwd_desc,
                                       const mkldnn::engine &engine) {
-    const std::string key_batch_norm_fwd_pd = key_ + "@bn_fwd_pd";
-    auto batch_norm_pd =
-        std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
-            dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
-    if (batch_norm_pd == nullptr) {
-      batch_norm_pd_.reset(
-          new batch_norm_fwd::primitive_desc(bn_fwd_desc, engine));
-      dev_ctx_.SetBlob(key_batch_norm_fwd_pd, batch_norm_pd_);
-    } else {
-      batch_norm_pd_ = batch_norm_pd;
-      is_reusing_ = true;
+    // BatchNorm PD has to be passed to the Grad op that
+    // may be executed by a different thread, hence
+    // for that one we use a key that does not contain the TID
+    const std::string key_batch_norm_fwd_pd = key_common_ + "@bn_fwd_pd";
+    batch_norm_pd_ = std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
+        dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
+    if (batch_norm_pd_ == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+      batch_norm_pd_ = std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
+          dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
+      if (batch_norm_pd_ == nullptr) {
+        batch_norm_pd_.reset(
+            new batch_norm_fwd::primitive_desc(bn_fwd_desc, engine));
+        dev_ctx_.SetBlob(key_batch_norm_fwd_pd, batch_norm_pd_);
+      }
     }
     return batch_norm_pd_;
   }
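The rewritten acquisition above combines the shared key_common_ key with double-checked locking: read the blob, and only on a miss take a function-local static mutex, re-check, and create, so the primitive descriptor is built exactly once even when several threads race on the first call. A condensed standalone sketch of the same pattern, with a plain map plus hypothetical GetBlob/SetBlob helpers standing in for the device context:

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Stand-in for the device context's blob storage: Get/Set are individually
// thread-safe, mirroring what the diff assumes of dev_ctx_.GetBlob/SetBlob.
static std::unordered_map<std::string, std::shared_ptr<int>> g_blobs;
static std::mutex g_storage_mutex;

static std::shared_ptr<int> GetBlob(const std::string& key) {
  std::lock_guard<std::mutex> guard(g_storage_mutex);
  auto it = g_blobs.find(key);
  return it == g_blobs.end() ? nullptr : it->second;
}

static void SetBlob(const std::string& key, std::shared_ptr<int> value) {
  std::lock_guard<std::mutex> guard(g_storage_mutex);
  g_blobs[key] = std::move(value);
}

// Double-checked acquisition: cheap read first, then lock and re-check only
// when the object is missing, so the expensive creation happens exactly once.
std::shared_ptr<int> AcquireShared(const std::string& key) {
  auto pd = GetBlob(key);
  if (pd == nullptr) {
    static std::mutex acquire_barrier;   // one barrier per acquire site
    std::lock_guard<std::mutex> guard(acquire_barrier);
    pd = GetBlob(key);                   // re-check under the lock
    if (pd == nullptr) {
      pd = std::make_shared<int>(42);    // stands in for the MKL-DNN PD
      SetBlob(key, pd);
    }
  }
  return pd;
}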
@@ -87,9 +92,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
     auto batch_norm_p =
         std::static_pointer_cast<batch_norm_fwd>(dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE((batch_norm_p != nullptr) || !is_reusing_,
-                   "Fail to find batch norm primitive in device context");
     if (batch_norm_p == nullptr) {
       if (is_test) {
         batch_norm_p = std::make_shared<batch_norm_fwd>(
@@ -104,8 +106,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
       }
       dev_ctx_.SetBlob(prim_key, batch_norm_p);
-    } else {
-      is_reusing_ = true;
     }
     return batch_norm_p;

@@ -54,18 +54,24 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
   std::shared_ptr<softmax_forward::primitive_desc>
   AcquireSoftmaxPrimitiveDescriptor(const softmax_forward::desc& softmax_desc,
                                     const mkldnn::engine& engine) {
-    const std::string key_softmax_pd = key_ + "@softmax_pd";
+    // Softmax PD has to be passed to the Grad op that
+    // may be executed by a different thread, hence
+    // for that one we use a key that does not contain the TID
+    const std::string key_softmax_pd = key_common_ + "@softmax_pd";
-    auto softmax_pd = std::static_pointer_cast<softmax_forward::primitive_desc>(
+    softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
         dev_ctx_.GetBlob(key_softmax_pd));
-    if (softmax_pd == nullptr) {
-      softmax_pd_.reset(
-          new softmax_forward::primitive_desc(softmax_desc, engine));
-      dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_);
-    } else {
-      softmax_pd_ = softmax_pd;
-      is_reusing_ = true;
+    if (softmax_pd_ == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+      softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
+          dev_ctx_.GetBlob(key_softmax_pd));
+      if (softmax_pd_ == nullptr) {
+        softmax_pd_.reset(
+            new softmax_forward::primitive_desc(softmax_desc, engine));
+        dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_);
+      }
     }
     return softmax_pd_;
@@ -79,15 +85,11 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
         dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE((softmax_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find softmax primitive in device context");
     if (softmax_p == nullptr) {
       softmax_p = std::make_shared<mkldnn::softmax_forward>(
           *softmax_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
           *(static_cast<mkldnn::memory*>(dst_memory_p.get())));
       dev_ctx_.SetBlob(prim_key, softmax_p);
-    } else {
-      is_reusing_ = true;
     }
     return softmax_p;
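With the thread id baked into prim_key, a cache miss is no longer a reuse error, so the is_reusing_ flag and its PADDLE_ENFORCE guard are dropped and the acquire path reduces to create-or-reuse. A minimal sketch of that shape, with a hypothetical map passed in place of the device context's blob storage:

#include <memory>
#include <string>
#include <unordered_map>

// Create-or-reuse, no is_reusing_ bookkeeping: a miss just means
// "this thread has not built the primitive yet".
std::shared_ptr<int> AcquirePrimitive(
    std::unordered_map<std::string, std::shared_ptr<int>>* blobs,
    const std::string& prim_key_with_tid) {
  auto& slot = (*blobs)[prim_key_with_tid];   // per-thread key, no barrier needed
  if (slot == nullptr) {
    slot = std::make_shared<int>(0);          // stands in for mkldnn::softmax_forward
  }
  return slot;
}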
@@ -100,15 +102,11 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     auto prim_key = key_ + "@softmax_bwd_p";
     auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
         dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE((softmax_bwd_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find softmax backward primitive in device context");
     if (softmax_bwd_p == nullptr) {
       softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
           *softmax_bwd_pd_, *dst_memory_p, *diff_dst_memory_p,
           *diff_src_memory_p);
       dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
-    } else {
-      is_reusing_ = true;
     }
     return softmax_bwd_p;

File diff suppressed because it is too large.