|
|
|
@ -500,17 +500,17 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
|
|
|
|
|
if (!this->isCached()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x->layout(), DataLayout::kMKLDNN,
|
|
|
|
|
platform::errors::InvalidArgument("Wrong layout set for X tensor"));
|
|
|
|
|
platform::errors::InvalidArgument("Wrong layout set for X tensor."));
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
x->format(), MKLDNNMemoryFormat::undef,
|
|
|
|
|
platform::errors::InvalidArgument("Wrong format set for X tensor"));
|
|
|
|
|
platform::errors::InvalidArgument("Wrong format set for X tensor."));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
y->layout(), DataLayout::kMKLDNN,
|
|
|
|
|
platform::errors::InvalidArgument("Wrong layout set for Y tensor"));
|
|
|
|
|
platform::errors::InvalidArgument("Wrong layout set for Y tensor."));
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
y->format(), MKLDNNMemoryFormat::undef,
|
|
|
|
|
platform::errors::InvalidArgument("Wrong format set for Y tensor"));
|
|
|
|
|
platform::errors::InvalidArgument("Wrong format set for Y tensor."));
|
|
|
|
|
|
|
|
|
|
const auto src_x_tz = framework::vectorize(x->dims());
|
|
|
|
|
const auto src_y_tz = framework::vectorize(y->dims());
|
|
|
|
@ -774,10 +774,10 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
|
|
|
|
|
if (!this->isCached()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Wrong layout set for Input tensor"));
|
|
|
|
|
"Wrong layout set for Input tensor."));
|
|
|
|
|
PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Wrong format set for Input tensor"));
|
|
|
|
|
"Wrong format set for Input tensor."));
|
|
|
|
|
|
|
|
|
|
const std::string pooling_type = ctx.Attr<std::string>("pooling_type");
|
|
|
|
|
|
|
|
|
@ -795,15 +795,21 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
|
|
|
|
|
ctx.Attr<std::string>("padding_algorithm");
|
|
|
|
|
|
|
|
|
|
// Only 2D pooling is supported now
|
|
|
|
|
PADDLE_ENFORCE_EQ(ksize.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"ksize must be 2D, i.e. 2D pooling"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(pooling_type == "max" || pooling_type == "avg", true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"pooling_type must be 'max' or 'avg'"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(input->dims().size(), 4,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input dim must be with 4, i.e. NCHW"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ksize.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The ksize must be 2D, i.e. 2D pooling, but received %dD.",
|
|
|
|
|
ksize.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
pooling_type == "max" || pooling_type == "avg", true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The pooling_type must be 'max' or 'avg', but received %s.",
|
|
|
|
|
pooling_type));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input->dims().size(), 4,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input dim must be with 4, i.e. NCHW, but received %d.",
|
|
|
|
|
input->dims().size()));
|
|
|
|
|
|
|
|
|
|
const auto input_dims = input->dims();
|
|
|
|
|
framework::DDim data_dims =
|
|
|
|
@ -1421,7 +1427,7 @@ static std::shared_ptr<mkldnn::memory> SetDstMemory(
|
|
|
|
|
residual_param_data,
|
|
|
|
|
platform::errors::PreconditionNotMet("Residual parameter is required for "
|
|
|
|
|
"the DNNL conv+elementwise_add "
|
|
|
|
|
"fusion, but now it is missing"));
|
|
|
|
|
"fusion, but now it is missing."));
|
|
|
|
|
std::shared_ptr<mkldnn::memory> user_residual_memory_p =
|
|
|
|
|
handler->AcquireResidualDataMemory(user_residual_md,
|
|
|
|
|
to_void_cast<T>(residual_param_data));
|
|
|
|
|