|
|
|
@ -495,13 +495,21 @@ class Pad2dOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE_EQ(paddings.size(), 4,
|
|
|
|
|
"Size of paddings should be equal to 4.");
|
|
|
|
|
if (data_format == "NCHW") {
|
|
|
|
|
out_dims[1] = x_dim[1];
|
|
|
|
|
out_dims[2] = x_dim[2] + paddings[0] + paddings[1]; // height
|
|
|
|
|
out_dims[3] = x_dim[3] + paddings[2] + paddings[3]; // width
|
|
|
|
|
} else { // NHWC
|
|
|
|
|
out_dims[3] = x_dim[3];
|
|
|
|
|
out_dims[1] = x_dim[1] + paddings[0] + paddings[1];
|
|
|
|
|
out_dims[2] = x_dim[2] + paddings[2] + paddings[3];
|
|
|
|
|
out_dims[1] = x_dim[1]; // channel
|
|
|
|
|
out_dims[2] = ((!ctx->IsRuntime()) && (x_dim[2] < 0))
|
|
|
|
|
? x_dim[2]
|
|
|
|
|
: (x_dim[2] + paddings[0] + paddings[1]); // height
|
|
|
|
|
out_dims[3] = ((!ctx->IsRuntime()) && (x_dim[3] < 0))
|
|
|
|
|
? x_dim[3]
|
|
|
|
|
: (x_dim[3] + paddings[2] + paddings[3]); // width
|
|
|
|
|
} else { // NHWC
|
|
|
|
|
out_dims[3] = x_dim[3]; // channel
|
|
|
|
|
out_dims[1] = ((!ctx->IsRuntime()) && (x_dim[1] < 0))
|
|
|
|
|
? x_dim[1]
|
|
|
|
|
: (x_dim[1] + paddings[0] + paddings[1]); // height
|
|
|
|
|
out_dims[2] = ((!ctx->IsRuntime()) && (x_dim[2] < 0))
|
|
|
|
|
? x_dim[2]
|
|
|
|
|
: (x_dim[2] + paddings[2] + paddings[3]); // width
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|