|
|
|
@ -151,6 +151,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
// TODO(jczaja): Add support for NHWC
|
|
|
|
|
const std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
PADDLE_ENFORCE_NE(data_format, "NHWC",
|
|
|
|
|
platform::errors::Unimplemented(
|
|
|
|
|
"Conv MKLDNN does not support NHWC data format yet"));
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
data_format, "NDHWC",
|
|
|
|
|
platform::errors::Unimplemented(
|
|
|
|
|
"Conv MKLDNN does not support NDHWC data format yet"));
|
|
|
|
|
library = framework::LibraryType::kMKLDNN;
|
|
|
|
|
layout = framework::DataLayout::kMKLDNN;
|
|
|
|
|
customized_type_value =
|
|
|
|
@ -524,6 +533,16 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
// TODO(jczaja): Add support for NHWC
|
|
|
|
|
const std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
data_format, "NHWC",
|
|
|
|
|
platform::errors::Unimplemented(
|
|
|
|
|
"Conv MKLDNN grad does not support NHWC data format yet"));
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
data_format, "NDHWC",
|
|
|
|
|
platform::errors::Unimplemented(
|
|
|
|
|
"Conv MKLDNN Grad does not support NDHWC data format yet"));
|
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
|
layout_ = framework::DataLayout::kMKLDNN;
|
|
|
|
|
customized_type_value = kConvMKLDNNFP32;
|
|
|
|
@ -706,14 +725,6 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
|
|
|
|
|
if (platform::CanCUDNNBeUsed(ctx)) {
|
|
|
|
|
library_ = framework::LibraryType::kCUDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
|
layout_ = framework::DataLayout::kMKLDNN;
|
|
|
|
|
customized_type_value = kConvMKLDNNFP32;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
auto type = framework::OpKernelType(
|
|
|
|
|
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
|
|
|
|
|