|
|
|
@ -54,7 +54,7 @@ class MKLDNNHandlerT {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<TForward> AcquireForwardPrimitive() {
|
|
|
|
|
const std::string key_p = key_ + "@forward_p";
|
|
|
|
|
const std::string key_p = key_ + "@fwd_p";
|
|
|
|
|
auto forward_p =
|
|
|
|
|
std::static_pointer_cast<TForward>(dev_ctx_.GetBlob(key_p));
|
|
|
|
|
if (forward_p == nullptr) {
|
|
|
|
@ -65,7 +65,7 @@ class MKLDNNHandlerT {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<TBackward> AcquireBackwardPrimitive() {
|
|
|
|
|
const std::string key_p = key_ + "@backward_p";
|
|
|
|
|
const std::string key_p = key_ + "@bwd_p";
|
|
|
|
|
auto backward_p =
|
|
|
|
|
std::static_pointer_cast<TBackward>(dev_ctx_.GetBlob(key_p));
|
|
|
|
|
if (backward_p == nullptr) {
|
|
|
|
@ -112,11 +112,11 @@ class MKLDNNHandlerT {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
bool isCached() {
|
|
|
|
|
const std::string key_pd = key_common_ + "@forward_pd";
|
|
|
|
|
const std::string key_pd = key_common_ + "@fwd_pd";
|
|
|
|
|
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
|
|
|
|
|
dev_ctx_.GetBlob(key_pd));
|
|
|
|
|
|
|
|
|
|
const std::string key_p = key_ + "@forward_p";
|
|
|
|
|
const std::string key_p = key_ + "@fwd_p";
|
|
|
|
|
return (dev_ctx_.GetBlob(key_p) != nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -129,7 +129,7 @@ class MKLDNNHandlerT {
|
|
|
|
|
// Forward PD has to be passed to Grad op that
|
|
|
|
|
// may be executed by diffrent thread, hence
|
|
|
|
|
// for that one we use key that does not contain TID
|
|
|
|
|
const std::string key_pd = key_common_ + "@forward_pd";
|
|
|
|
|
const std::string key_pd = key_common_ + "@fwd_pd";
|
|
|
|
|
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
|
|
|
|
|
dev_ctx_.GetBlob(key_pd));
|
|
|
|
|
if (fwd_pd_ == nullptr) {
|
|
|
|
@ -169,13 +169,13 @@ class MKLDNNHandlerT {
|
|
|
|
|
|
|
|
|
|
template <typename... Args>
|
|
|
|
|
void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
|
|
|
|
|
const std::string key_fwd_pd = key_common_ + "@forward_pd";
|
|
|
|
|
const std::string key_fwd_pd = key_common_ + "@fwd_pd";
|
|
|
|
|
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
|
|
|
|
|
dev_ctx_.GetBlob(key_fwd_pd));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
fwd_pd_, platform::errors::Unavailable(
|
|
|
|
|
"Get MKLDNN Forward primitive %s failed.", key_fwd_pd));
|
|
|
|
|
const std::string key_pd = key_ + "@backward_pd";
|
|
|
|
|
const std::string key_pd = key_ + "@bwd_pd";
|
|
|
|
|
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
|
|
|
|
|
dev_ctx_.GetBlob(key_pd));
|
|
|
|
|
if (bwd_pd_ == nullptr) {
|
|
|
|
|