From bc902044a47920347a3e005f995eb9b68ded57b4 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Fri, 27 Nov 2020 05:55:57 +0100 Subject: [PATCH] Fixes mkldnn dygraph learning rate scheduler crashes (#28988) --- paddle/fluid/framework/operator.cc | 17 +++++++++++++++++ paddle/fluid/framework/operator.h | 5 +++++ paddle/fluid/operators/activation_op.cc | 2 +- paddle/fluid/operators/addmm_op.cc | 2 +- paddle/fluid/operators/batch_norm_op.cc | 6 ++---- paddle/fluid/operators/concat_op.cc | 2 +- paddle/fluid/operators/conv_op.cc | 5 ++--- paddle/fluid/operators/conv_transpose_op.cc | 2 +- paddle/fluid/operators/data_norm_op.cc | 4 ++-- .../fluid/operators/detection/prior_box_op.cc | 2 +- .../operators/elementwise/elementwise_div_op.h | 2 +- .../operators/elementwise/elementwise_mul_op.h | 2 +- .../operators/elementwise/elementwise_op.h | 11 +++++------ paddle/fluid/operators/fused/fusion_gru_op.cc | 2 +- paddle/fluid/operators/gaussian_random_op.cc | 2 +- paddle/fluid/operators/gelu_op.cc | 4 ++-- paddle/fluid/operators/layer_norm_op.cc | 2 +- paddle/fluid/operators/lrn_op.cc | 4 ++-- paddle/fluid/operators/matmul_op.cc | 2 +- paddle/fluid/operators/mul_op.cc | 2 +- paddle/fluid/operators/pool_op.cc | 4 ++-- paddle/fluid/operators/softmax_op.cc | 4 ++-- paddle/fluid/operators/sum_op.cc | 2 +- paddle/fluid/operators/transpose_op.cc | 8 ++++---- paddle/fluid/platform/mkldnn_helper.h | 5 ----- 25 files changed, 58 insertions(+), 45 deletions(-) mode change 100755 => 100644 paddle/fluid/operators/activation_op.cc diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 21fc293e84..026c1092eb 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1007,6 +1007,23 @@ static void CheckTensorNANOrInf(const std::string& op_type, op_type, name)); } +bool OperatorWithKernel::SupportsMKLDNN() const { + auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_); + return std::any_of(op_kernels.begin(), op_kernels.end(), + [](OpKernelMap::const_reference kern_pair) { + return platform::is_cpu_place(kern_pair.first.place_) && + kern_pair.first.library_type_ == + LibraryType::kMKLDNN; + }); +} + +bool OperatorWithKernel::CanMKLDNNBeUsed( + const framework::ExecutionContext& ctx) const { + bool use_mkldnn_ctx = + ctx.Attr("use_mkldnn") && platform::is_cpu_place(ctx.GetPlace()); + return use_mkldnn_ctx && this->SupportsMKLDNN(); +} + void OperatorWithKernel::RuntimeInferShape(const Scope& scope, const platform::Place& place, const RuntimeContext& ctx) const { diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index d493f350e6..d5107ef5ca 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -156,6 +156,8 @@ class OperatorBase { virtual bool SupportGPU() const { return false; } + virtual bool SupportsMKLDNN() const { return false; } + const std::string& Type() const { return type_; } bool HasAttr(const std::string& name) const { return attrs_.count(name); } @@ -490,6 +492,9 @@ class OperatorWithKernel : public OperatorBase { return platform::is_gpu_place(kern_pair.first.place_); }); } + bool SupportsMKLDNN() const override; + + bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) const; virtual void InferShape(InferShapeContext* ctx) const = 0; diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc old mode 100755 new mode 100644 index 40951d5960..26b4ed71e0 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -106,7 +106,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, #ifdef PADDLE_WITH_MKLDNN auto it = oper.Attrs().find("use_mkldnn"); if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() && - platform::CanMKLDNNBeUsed(ctx)) { + oper.CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/addmm_op.cc b/paddle/fluid/operators/addmm_op.cc index f6e6856c61..f5b35cbd21 100644 --- a/paddle/fluid/operators/addmm_op.cc +++ b/paddle/fluid/operators/addmm_op.cc @@ -119,7 +119,7 @@ class AddMMOp : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 370ba8619f..f74aa259e8 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -157,8 +157,7 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType( framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } @@ -527,8 +526,7 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 7937e432d2..0b3697156d 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -83,7 +83,7 @@ class ConcatOp : public framework::OperatorWithKernel { "All Inputs of Concat OP are Empty!")); } #ifdef PADDLE_WITH_MKLDNN - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 76ff1084fa..72355c7d3a 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -155,8 +155,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( } #endif #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; customized_type_value = @@ -565,7 +564,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { const std::string data_format = ctx.Attr("data_format"); library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index 7e0e77214c..6c48448555 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -193,7 +193,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc index 45e77a99e6..7dc1e23207 100644 --- a/paddle/fluid/operators/data_norm_op.cc +++ b/paddle/fluid/operators/data_norm_op.cc @@ -184,7 +184,7 @@ class DataNormOp : public framework::OperatorWithKernel { framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } @@ -486,7 +486,7 @@ class DataNormGradOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc index 0d293bb964..ef6332b641 100644 --- a/paddle/fluid/operators/detection/prior_box_op.cc +++ b/paddle/fluid/operators/detection/prior_box_op.cc @@ -98,7 +98,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; auto input_image_type = ctx.Input("Image")->type(); diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index 5ac3ffe225..1d016fba34 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -163,7 +163,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); #ifdef PADDLE_WITH_MKLDNN - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index e4d3ea6d72..49456149c2 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -33,7 +33,7 @@ class ElementwiseMulOp : public ElementwiseOp { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index ece6af1b5a..bbb240efae 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -108,7 +108,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); @@ -265,9 +265,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { return (ctx.Input("X")->dims() == ctx.Input("Y")->dims()); }; - if (platform::CanMKLDNNBeUsed(ctx) && - (ctx.Type() != "elementwise_add_grad" || - CanMKLDNNElementwiseAddGradBeUsed())) { + if (this->CanMKLDNNBeUsed(ctx) && (ctx.Type() != "elementwise_add_grad" || + CanMKLDNNElementwiseAddGradBeUsed())) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); @@ -304,7 +303,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut"); #ifdef PADDLE_WITH_MKLDNN - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); @@ -343,7 +342,7 @@ class ElementwiseOpDoubleGradWithoutDXDY } #ifdef PADDLE_WITH_MKLDNN - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index e3776a80b3..f5904039d4 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -133,7 +133,7 @@ framework::OpKernelType FusionGRUOp::GetExpectedKernelType( framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc index fd2f48265c..840975f754 100644 --- a/paddle/fluid/operators/gaussian_random_op.cc +++ b/paddle/fluid/operators/gaussian_random_op.cc @@ -115,7 +115,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/gelu_op.cc b/paddle/fluid/operators/gelu_op.cc index 9ca0d30362..6c33b05cac 100644 --- a/paddle/fluid/operators/gelu_op.cc +++ b/paddle/fluid/operators/gelu_op.cc @@ -49,7 +49,7 @@ class GeluOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN auto it = this->Attrs().find("use_mkldnn"); if (library == framework::LibraryType::kPlain && - it != this->Attrs().end() && platform::CanMKLDNNBeUsed(ctx)) { + it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } @@ -89,7 +89,7 @@ class GeluGradOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN auto it = this->Attrs().find("use_mkldnn"); if (library == framework::LibraryType::kPlain && - it != this->Attrs().end() && platform::CanMKLDNNBeUsed(ctx)) { + it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc index 79e3d3b90a..6f83a667a5 100644 --- a/paddle/fluid/operators/layer_norm_op.cc +++ b/paddle/fluid/operators/layer_norm_op.cc @@ -104,7 +104,7 @@ class LayerNormOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index fc9d61eb75..2d4123ccbd 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -201,7 +201,7 @@ class LRNOp : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } @@ -341,7 +341,7 @@ class LRNOpGrad : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index 129298edaf..639a6991a4 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -659,7 +659,7 @@ class MatMulOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN using mkldnn::memory; - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index b3afba1e4f..9d6c52b98a 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -106,7 +106,7 @@ class MulOp : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index 5b0980a985..b78ced8eee 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -157,7 +157,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } @@ -213,7 +213,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index ff25f19110..ff750ab47a 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -72,7 +72,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } @@ -196,7 +196,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index faade79091..57fa92b199 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -147,7 +147,7 @@ class SumOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx) && + this->CanMKLDNNBeUsed(ctx) && (static_cast(dtype) == framework::proto::VarType::FP32 || static_cast(dtype) == diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 0e870937ec..a098327ab2 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -88,7 +88,7 @@ class TransposeOp : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } @@ -186,7 +186,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } @@ -233,7 +233,7 @@ class Transpose2Op : public TransposeOp { framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; using framework::proto::VarType; @@ -298,7 +298,7 @@ class Transpose2OpGrad : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 34f5759e4c..797ff42f3c 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -134,11 +134,6 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector& dims, return mkldnn::memory::desc({dims}, data_type, format); } -inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) { - bool use_mkldnn = ctx.Attr("use_mkldnn"); - return use_mkldnn && platform::is_cpu_place(ctx.GetPlace()); -} - inline void ClearMKLDNNCache(const platform::Place& place) { // Clear mkl-dnn cache, if (platform::is_cpu_place(place)) {