[oneDNN] cache cosmetics improvement (#25576)

fix_copy_if_different
Jacek Czaja 5 years ago committed by GitHub
parent 1a5d3defb1
commit 7dbc441eab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -943,7 +943,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const std::string key = platform::CreateKey(
src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));
const std::string key_conv_pd = key + "@forward_pd";
const std::string key_conv_pd = key + "@fwd_pd";
std::vector<primitive> pipeline;
// Create user memory descriptors

@ -54,7 +54,7 @@ class MKLDNNHandlerT {
}
std::shared_ptr<TForward> AcquireForwardPrimitive() {
const std::string key_p = key_ + "@forward_p";
const std::string key_p = key_ + "@fwd_p";
auto forward_p =
std::static_pointer_cast<TForward>(dev_ctx_.GetBlob(key_p));
if (forward_p == nullptr) {
@ -65,7 +65,7 @@ class MKLDNNHandlerT {
}
std::shared_ptr<TBackward> AcquireBackwardPrimitive() {
const std::string key_p = key_ + "@backward_p";
const std::string key_p = key_ + "@bwd_p";
auto backward_p =
std::static_pointer_cast<TBackward>(dev_ctx_.GetBlob(key_p));
if (backward_p == nullptr) {
@ -112,11 +112,11 @@ class MKLDNNHandlerT {
protected:
bool isCached() {
const std::string key_pd = key_common_ + "@forward_pd";
const std::string key_pd = key_common_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
const std::string key_p = key_ + "@forward_p";
const std::string key_p = key_ + "@fwd_p";
return (dev_ctx_.GetBlob(key_p) != nullptr);
}
@ -129,7 +129,7 @@ class MKLDNNHandlerT {
// Forward PD has to be passed to Grad op that
// may be executed by diffrent thread, hence
// for that one we use key that does not contain TID
const std::string key_pd = key_common_ + "@forward_pd";
const std::string key_pd = key_common_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (fwd_pd_ == nullptr) {
@ -169,13 +169,13 @@ class MKLDNNHandlerT {
template <typename... Args>
void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
const std::string key_fwd_pd = key_common_ + "@forward_pd";
const std::string key_fwd_pd = key_common_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_fwd_pd));
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_, platform::errors::Unavailable(
"Get MKLDNN Forward primitive %s failed.", key_fwd_pd));
const std::string key_pd = key_ + "@backward_pd";
const std::string key_pd = key_ + "@bwd_pd";
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (bwd_pd_ == nullptr) {

Loading…
Cancel
Save