@@ -61,10 +61,12 @@ class UnfoldOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of UnfoldOp should not be null");
-    PADDLE_ENFORCE(ctx->HasOutput("Y"),
-                   "Output(Y) of UnfoldOp should not be null");
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("X"), true,
+        platform::errors::NotFound("Input(X) of UnfoldOp should not be null"));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput("Y"), true,
+        platform::errors::NotFound("Output(Y) of UnfoldOp should not be null"));
     auto in_dims = ctx->GetInputDim("X");
     std::vector<int> kernel_sizes =
         ctx->Attrs().Get<std::vector<int>>("kernel_sizes");
@@ -74,31 +76,36 @@ class UnfoldOp : public framework::OperatorWithKernel {
         ctx->Attrs().Get<std::vector<int>>("dilations");
 
     // Only [N, C, H, W] input supported now
-    PADDLE_ENFORCE(
-        in_dims.size() == 4,
-        "Input should be 4-D tensor of format [N, C, H, W], but get %u",
-        in_dims.size());
-    PADDLE_ENFORCE(
-        in_dims.size() - kernel_sizes.size() == 2U,
-        "The dims of X should be larger than that of kernel_sizes "
-        "by a number of 2, due to the batch size and input channel dim. "
-        "But recieved dims(X:%u) - dims(kernel_sizes:%u) != 2",
-        in_dims.size(), kernel_sizes.size());
+    PADDLE_ENFORCE_EQ(
+        in_dims.size(), 4,
+        platform::errors::InvalidArgument(
+            "Input should be 4-D tensor of format [N, C, H, W], but get %u",
+            in_dims.size()));
+    PADDLE_ENFORCE_EQ(
+        in_dims.size() - kernel_sizes.size(), 2U,
+        platform::errors::InvalidArgument(
+            "The dims of X should be larger than that of kernel_sizes "
+            "by a number of 2, due to the batch size and input channel dim. "
+            "But recieved dims(X:%u) - dims(kernel_sizes:%u) != 2",
+            in_dims.size(), kernel_sizes.size()));
     PADDLE_ENFORCE_EQ(
         strides.size(), kernel_sizes.size(),
-        "The dims of strides should be the same with that of kernel_sizes. "
-        "But recieved dims(strides: %u) != dims(kernel_sizes: %u).",
-        strides.size(), kernel_sizes.size());
+        platform::errors::InvalidArgument(
+            "The dims of strides should be the same with that of kernel_sizes. "
+            "But recieved dims(strides: %u) != dims(kernel_sizes: %u).",
+            strides.size(), kernel_sizes.size()));
     PADDLE_ENFORCE_EQ(
         paddings.size(), 2 * strides.size(),
-        "The dims of paddings should be 2 times of that of strides. "
-        "But recieved dims(paddings: %u) != 2*dims(strides: %u).",
-        paddings.size(), strides.size());
+        platform::errors::InvalidArgument(
+            "The dims of paddings should be 2 times of that of strides. "
+            "But recieved dims(paddings: %u) != 2*dims(strides: %u).",
+            paddings.size(), strides.size()));
     PADDLE_ENFORCE_EQ(
         strides.size(), dilations.size(),
-        "The dims of strides should be the same with that of dilations. "
-        "But recieved dims(strides: %u) != dims(dilations: %u).",
-        strides.size(), dilations.size());
+        platform::errors::InvalidArgument(
+            "The dims of strides should be the same with that of dilations. "
+            "But recieved dims(strides: %u) != dims(dilations: %u).",
+            strides.size(), dilations.size()));
 
     std::vector<int> out_dims;
     out_dims.push_back(in_dims[0]);
@@ -131,11 +138,15 @@ class UnfoldGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
-                   "The gradient of Y should not be null");
-    PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
-                   "The gradient of X should not be null");
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput(framework::GradVarName("Y")), true,
+        platform::errors::NotFound("The gradient of Y should not be null"));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("X"), true,
+        platform::errors::NotFound("The input X should not be null"));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput(framework::GradVarName("X")), true,
+        platform::errors::NotFound("The gradient of X should not be null"));
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
 
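
Editor's note (not part of the patch): every hunk above applies the same migration. A boolean-style PADDLE_ENFORCE(cond, msg, ...) becomes PADDLE_ENFORCE_EQ(lhs, rhs, platform::errors::NotFound(...)) or platform::errors::InvalidArgument(...), so a failed check reports an explicit expected/actual pair plus a typed error category rather than a bare message. The standalone C++ sketch below imitates that shape with hypothetical CHECK/CHECK_EQ macros and a plain TypedError struct; it is only an illustration of the before/after pattern, not Paddle's actual macro implementation.

#include <cstdio>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for platform::errors::NotFound / InvalidArgument.
struct TypedError {
  std::string category;
  std::string message;
};

inline TypedError NotFound(const std::string& msg) { return {"NotFound", msg}; }
inline TypedError InvalidArgument(const std::string& msg) {
  return {"InvalidArgument", msg};
}

// Old style: only a condition and a free-form message.
#define CHECK(cond, msg)                        \
  do {                                          \
    if (!(cond)) throw std::runtime_error(msg); \
  } while (0)

// New style: compare two values and attach a typed error, mirroring
// PADDLE_ENFORCE_EQ(..., platform::errors::InvalidArgument(...)).
#define CHECK_EQ(lhs, rhs, typed_err)                         \
  do {                                                        \
    if ((lhs) != (rhs))                                       \
      throw std::runtime_error((typed_err).category + ": " +  \
                               (typed_err).message);          \
  } while (0)

int main() {
  int in_dims_rank = 3;  // pretend the input tensor is 3-D, not the required 4-D

  // Old style: the failure message is all the caller gets.
  try {
    CHECK(in_dims_rank == 4, "Input should be 4-D tensor of format [N, C, H, W]");
  } catch (const std::exception& e) {
    std::printf("old-style check: %s\n", e.what());
  }

  // New style: explicit expected/actual comparison plus a typed error category.
  try {
    CHECK_EQ(in_dims_rank, 4,
             InvalidArgument("Input should be 4-D tensor of format "
                             "[N, C, H, W], but got rank " +
                             std::to_string(in_dims_rank)));
  } catch (const std::exception& e) {
    std::printf("new-style check: %s\n", e.what());
  }
  return 0;
}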