|
|
|
@ -88,7 +88,10 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
ksize.size(), strides.size(), framework::make_ddim(ksize),
|
|
|
|
|
framework::make_ddim(strides));
|
|
|
|
|
|
|
|
|
|
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
|
|
|
|
|
// MKL-DNN Kernels are using NCHW order of dims description
|
|
|
|
|
// so we ignore data_format consideration for MKL-DNN kernel
|
|
|
|
|
const bool channel_last = (this->IsMKLDNNType() == false) &&
|
|
|
|
|
(data_format == "NHWC" || data_format == "NDHWC");
|
|
|
|
|
|
|
|
|
|
// update paddings if "SAME" or global_pooling
|
|
|
|
|
framework::DDim data_dims;
|
|
|
|
@ -146,12 +149,6 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
// TODO(jczaja): Add support for NHWC
|
|
|
|
|
const std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
data_format, "NHWC",
|
|
|
|
|
platform::errors::Unimplemented(
|
|
|
|
|
"Pool MKLDNN grad does not support NHWC data format yet"));
|
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
|
layout_ = framework::DataLayout::kMKLDNN;
|
|
|
|
|
}
|
|
|
|
@ -162,6 +159,28 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
|
|
|
|
|
layout_, library_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType PoolOp::GetKernelTypeForVar(
|
|
|
|
|
const std::string& var_name, const Tensor& tensor,
|
|
|
|
|
const framework::OpKernelType& expected_kernel_type) const {
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
|
|
|
|
|
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
|
|
|
|
|
auto attrs = Attrs();
|
|
|
|
|
auto ar = paddle::framework::AttrReader(attrs);
|
|
|
|
|
const std::string data_format = ar.Get<std::string>("data_format");
|
|
|
|
|
auto dl = framework::StringToDataLayout(data_format);
|
|
|
|
|
// Some models may have intentionally set "AnyLayout" for pool
|
|
|
|
|
// op. Treat this as NCHW (default data_format value)
|
|
|
|
|
if (dl != framework::DataLayout::kAnyLayout) {
|
|
|
|
|
return framework::OpKernelType(expected_kernel_type.data_type_,
|
|
|
|
|
tensor.place(), dl);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
return framework::OpKernelType(expected_kernel_type.data_type_,
|
|
|
|
|
tensor.place(), tensor.layout());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) must not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
|
|
|
|
|