|
|
|
@@ -48,7 +48,11 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
int groups = ctx->Attrs().Get<int>("groups");
|
|
|
|
|
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
|
|
|
|
|
const std::string data_format = ctx->Attrs().Get<std::string>("data_format");
|
|
|
|
|
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
|
|
|
|
|
|
|
|
|
|
// MKL-DNN Kernels are using NCHW order of dims description
|
|
|
|
|
// so we ignore data_format consideration for MKL-DNN kernel
|
|
|
|
|
const bool channel_last = (this->IsMKLDNNType() == false) &&
|
|
|
|
|
(data_format == "NHWC" || data_format == "NDHWC");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in_dims.size() == 4 || in_dims.size() == 5, true,
|
|
|
|
@@ -151,15 +155,6 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
// TODO(jczaja): Add support for NHWC
|
|
|
|
|
const std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
PADDLE_ENFORCE_NE(data_format, "NHWC",
|
|
|
|
|
platform::errors::Unimplemented(
|
|
|
|
|
"Conv MKLDNN does not support NHWC data format yet"));
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
data_format, "NDHWC",
|
|
|
|
|
platform::errors::Unimplemented(
|
|
|
|
|
"Conv MKLDNN does not support NDHWC data format yet"));
|
|
|
|
|
library = framework::LibraryType::kMKLDNN;
|
|
|
|
|
layout = framework::DataLayout::kMKLDNN;
|
|
|
|
|
customized_type_value =
|
|
|
|
@@ -197,6 +192,32 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
|
|
|
|
|
return type;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Decides the kernel type (data layout) used for a particular input variable.
// When the MKL-DNN kernel was selected, only the "Input" tensor needs to be
// transformed to the layout named by the "data_format" attribute; "Filter"
// (weights) and "Bias" are always stored in NCHW order and keep their layout.
framework::OpKernelType ConvOp::GetKernelTypeForVar(
    const std::string& var_name, const Tensor& tensor,
    const framework::OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
  // Only input requires reshaping; weights and
  // bias are having shape in NCHW order
  if ((var_name == "Input") &&
      (expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
      (tensor.layout() != framework::DataLayout::kMKLDNN)) {
    auto attrs = Attrs();
    auto ar = paddle::framework::AttrReader(attrs);
    const std::string data_format = ar.Get<std::string>("data_format");
    auto dl = framework::StringToDataLayout(data_format);
    // Some models may have intentionally set "AnyLayout" for conv
    // op. Treat this as NCHW (default data_format value)
    if (dl != framework::DataLayout::kAnyLayout) {
      // Reuse the layout already parsed into `dl` instead of calling
      // StringToDataLayout(data_format) a second time.
      return framework::OpKernelType(expected_kernel_type.data_type_,
                                     tensor.place(), dl);
    }
  }
#endif
  // Default: keep the tensor's own layout (no transformation needed).
  return framework::OpKernelType(expected_kernel_type.data_type_,
                                 tensor.place(), tensor.layout());
}
|
|
|
|
|
|
|
|
|
|
void Conv2DOpMaker::Make() {
|
|
|
|
|
AddAttr<bool>("is_test",
|
|
|
|
|
"(bool, default false) Set to true for inference only, false "
|
|
|
|
|