|
|
|
@ -110,19 +110,19 @@ class BatchNormOp : public framework::OperatorWithKernel {
|
|
|
|
|
ctx.Input<Tensor>("Variance")->type()),
|
|
|
|
|
"Variance input should be of float type");
|
|
|
|
|
|
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
|
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
|
|
|
|
|
framework::LibraryType library = framework::LibraryType::kPlain;
|
|
|
|
|
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
|
if (library == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
|
library = framework::LibraryType::kMKLDNN;
|
|
|
|
|
layout = framework::DataLayout::kMKLDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
|
|
|
|
|
library_);
|
|
|
|
|
library);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -370,19 +370,21 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_THROW("can't find Y@GRAD");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
|
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
|
|
|
|
|
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
|
|
|
|
|
framework::LibraryType library = framework::LibraryType::kPlain;
|
|
|
|
|
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
|
if (library == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
|
layout_ = framework::DataLayout::kMKLDNN;
|
|
|
|
|
library = framework::LibraryType::kMKLDNN;
|
|
|
|
|
layout = framework::DataLayout::kMKLDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
|
|
|
|
|
layout_, library_);
|
|
|
|
|
layout, library);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|