|
|
|
@ -99,10 +99,15 @@ class UnpoolOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(in_x_dims.size() == 4,
|
|
|
|
|
"Unpooling intput must be of 4-dimensional.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims);
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
|
|
|
|
|
for (size_t i = 0; i < ksize.size(); ++i) {
|
|
|
|
|
output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i],
|
|
|
|
|
paddings[i], strides[i]));
|
|
|
|
|
if (!ctx->IsRuntime() && in_x_dims[i + 2] <= 0) {
|
|
|
|
|
output_shape.push_back(-1);
|
|
|
|
|
} else {
|
|
|
|
|
output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i],
|
|
|
|
|
paddings[i], strides[i]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
|
|
|
|
|
}
|
|
|
|
|