|
|
|
@ -108,7 +108,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
if (this->CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
|
|
|
|
|
framework::DataLayout::kMKLDNN,
|
|
|
|
|
framework::LibraryType::kMKLDNN);
|
|
|
|
@ -265,9 +265,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
return (ctx.Input<Tensor>("X")->dims() == ctx.Input<Tensor>("Y")->dims());
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
if (platform::CanMKLDNNBeUsed(ctx) &&
|
|
|
|
|
(ctx.Type() != "elementwise_add_grad" ||
|
|
|
|
|
CanMKLDNNElementwiseAddGradBeUsed())) {
|
|
|
|
|
if (this->CanMKLDNNBeUsed(ctx) && (ctx.Type() != "elementwise_add_grad" ||
|
|
|
|
|
CanMKLDNNElementwiseAddGradBeUsed())) {
|
|
|
|
|
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
|
|
|
|
|
framework::DataLayout::kMKLDNN,
|
|
|
|
|
framework::LibraryType::kMKLDNN);
|
|
|
|
@ -304,7 +303,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
|
|
|
|
|
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
if (this->CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
|
|
|
|
|
framework::DataLayout::kMKLDNN,
|
|
|
|
|
framework::LibraryType::kMKLDNN);
|
|
|
|
@ -343,7 +342,7 @@ class ElementwiseOpDoubleGradWithoutDXDY
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
if (this->CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
|
|
|
|
|
framework::DataLayout::kMKLDNN,
|
|
|
|
|
framework::LibraryType::kMKLDNN);
|
|
|
|
|