You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/operators/unpool_op.cc

142 lines
5.7 KiB

7 years ago
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
7 years ago
#include "paddle/operators/unpool_op.h"
namespace paddle {
namespace operators {
// Declares the proto of the 2-D unpool operator: its two inputs ("X" and
// "Indices"), its output ("Out"), the window attributes ("ksize", "strides",
// "paddings", "unpooling_type") and the user-facing documentation text.
class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Unpool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
        "X",
        "(Tensor) The input tensor of unpool operator. "
        "The format of input tensor is NCHW. Where N is batch size, C is the "
        "number of channels, H and W is the height and width of feature.");
    AddInput(
        "Indices",
        "(Tensor) The input tensor of the indices given out by MaxPool2d. "
        "The format of input tensor is NCHW. Where N is batch size, C is the "
        "number of channels, H and W is the height and width of feature.");
    // Adjacent literals are concatenated verbatim, so every fragment except
    // the last must end with a space to keep the sentences separated.
    AddOutput("Out",
              "(Tensor) The output tensor of unpool operator. "
              "The format of output tensor is also NCHW. "
              "Where N is batch size, C is "
              "the number of channels, H and W is the height and "
              "width of feature.");
    // "ksize" has no default value; "strides" and "paddings" do.
    AddAttr<std::vector<int>>(
        "ksize",
        "(vector), the unpooling window size(height, width) "
        "of unpooling operator.");
    AddAttr<std::vector<int>>("strides",
                              "(vector, default:{1, 1}), "
                              "strides (height, width) of unpooling operator.")
        .SetDefault({1, 1});
    AddAttr<std::vector<int>>("paddings",
                              "(vector, default:{0, 0}), "
                              "paddings (height, width) of unpooling operator.")
        .SetDefault({0, 0});
    // Only "max" is accepted by the enum checker.
    AddAttr<std::string>(
        "unpooling_type",
        "(string), unpooling type, can be \"max\" for max-unpooling ")
        .InEnum({"max"});
    // The formula below mirrors OutputSize():
    //   (input - 1) * stride - 2 * padding + ksize
    AddComment(R"DOC(
Input shape is: $(N, C_{in}, H_{in}, W_{in})$, Output shape is:
$(N, C_{out}, H_{out}, W_{out})$, where
$$
H_{out} = (H_{in} - 1) * strides[0] - 2 * paddings[0] + ksize[0] \\
W_{out} = (W_{in} - 1) * strides[1] - 2 * paddings[1] + ksize[1]
$$
Paper: http://www.matthewzeiler.com/wp-content/uploads/2017/07/iccv2011.pdf
)DOC");
  }
};
// Computes the extent of one spatial dimension after unpooling.
//
// @param input_size extent of the pooled (input) dimension
// @param ksize      unpooling window size along this dimension
// @param padding    padding along this dimension
// @param stride     stride along this dimension
// @return (input_size - 1) * stride - 2 * padding + ksize
int OutputSize(int input_size, int ksize, int padding, int stride) {
  int output_size = (input_size - 1) * stride - 2 * padding + ksize;
  return output_size;
}
class UnpoolOp : public framework::OperatorWithKernel {
7 years ago
protected:
framework::OpKernelType GetExpectedKernelType(
7 years ago
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
7 years ago
ctx.device_context());
7 years ago
}
7 years ago
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
7 years ago
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of UnpoolOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Indices"),
"Input(Indices) of UnpoolOp"
7 years ago
"should not be null.");
7 years ago
PADDLE_ENFORCE(ctx->HasOutput("Out"),
7 years ago
"Output(Out) of UnpoolOp should not be null.");
7 years ago
auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Indices");
7 years ago
std::string unpooling_type =
ctx->Attrs().Get<std::string>("unpooling_type");
7 years ago
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
7 years ago
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
7 years ago
PADDLE_ENFORCE(in_x_dims.size() == 4,
7 years ago
"Unpooling intput must be of 4-dimensional.");
7 years ago
PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims);
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(
7 years ago
OutputSize(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
7 years ago
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
7 years ago
};
class UnpoolOpGrad : public framework::OperatorWithKernel {
7 years ago
protected:
framework::OpKernelType GetExpectedKernelType(
7 years ago
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
7 years ago
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
7 years ago
}
7 years ago
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
7 years ago
"Input(X@GRAD) should not be null.");
7 years ago
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
7 years ago
};
7 years ago
} // namespace operators
} // namespace paddle
7 years ago
namespace ops = paddle::operators;

// Register the forward op "unpool" (with its proto maker) together with its
// gradient op "unpool_grad".
REGISTER_OP(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool_grad,
            ops::UnpoolOpGrad);

// CPU kernels for the forward op, instantiated for float and double.
REGISTER_OP_CPU_KERNEL(
    unpool, ops::UnpoolKernel<paddle::platform::CPUDeviceContext, float>,
    ops::UnpoolKernel<paddle::platform::CPUDeviceContext, double>);

// CPU kernels for the gradient op, instantiated for float and double.
REGISTER_OP_CPU_KERNEL(
    unpool_grad,
    ops::UnpoolGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::UnpoolGradKernel<paddle::platform::CPUDeviceContext, double>);