modify for code review by wangyi

release/0.11.0
sweetsky0901 7 years ago
parent e553d5728d
commit 66b84366f1

@ -16,11 +16,9 @@
namespace paddle {
namespace operators {
using framework::Tensor;
class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Unpool2dOpMaker(framework::OpProto* proto, \
Unpool2dOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
@ -52,12 +50,12 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"(string), unpooling type, can be \"max\" for max-unpooling ")
.InEnum({"max"});
AddComment(R"DOC(
"input: the input Tensor to invert"
"indices: the indices given out by MaxPool2d"
"ksize Size of the max pooling window."
"stride Stride of the max pooling window."
"It is set to kernel_size by default."
"padding Padding that was added to the input"
"input: the input Tensor to invert
indices: the indices given out by MaxPool2d
ksize Size of the max pooling window.
stride Stride of the max pooling window.
"It is set to kernel_size by default.
padding Padding that was added to the input"
)DOC");
}
};
@ -80,14 +78,14 @@ class UnpoolOp : public framework::OperatorWithKernel {
auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Y");
std::string unpoolingtype = \
std::string unpoolingtype =
ctx->Attrs().Get<std::string>("unpoolingtype");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(in_x_dims.size() == 4,
"Unpooling intput should be 4-D.");
"Unpooling intput must be of 4-dimensional.");
for (int i = 0; i < 4; ++i) {
PADDLE_ENFORCE(in_x_dims[i] == in_y_dims[i],
"X size must be eq Y size!");

@ -21,15 +21,13 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class UnpoolKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
const Tensor* in_y = context.Input<Tensor>("Y");
auto * out = context.Output<Tensor>("Out");
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* in_y = context.Input<framework::Tensor>("Y");
auto * out = context.Output<framework::Tensor>("Out");
std::string unpoolingtype = context.Attr<std::string>("unpoolingtype");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
@ -39,28 +37,22 @@ class UnpoolKernel : public framework::OpKernel<T> {
math::SetConstant<Place, T> set_zero;
set_zero(context.device_context(), out, static_cast<T>(0));
}
switch (ksize.size()) {
case 2: {
if (unpoolingtype == "max") {
math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
}
} break;
default: { PADDLE_THROW("Pool op only supports 2D input."); }
}
}
};
template <typename Place, typename T>
class UnpoolGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
const Tensor* in_y = context.Input<Tensor>("Y");
const Tensor* out = context.Input<Tensor>("Out");
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* in_y = context.Input<framework::Tensor>("Y");
const framework::Tensor* out = context.Input<framework::Tensor>("Out");
const framework::Tensor* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor* in_x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
std::string unpoolingtype = context.Attr<std::string>("unpoolingtype");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
@ -70,19 +62,12 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
math::SetConstant<Place, T> zero;
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, in_x_grad, static_cast<T>(0.0));
zero(device_ctx, in_x_grad, static_cast<T>(0));
}
switch (ksize.size()) {
case 2: {
if (unpoolingtype == "max") {
math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad,
*out, *out_grad);
}
} break;
default: { PADDLE_THROW("Unpool op only supports 2D input."); }
}
}
};
} // namespace operators

Loading…
Cancel
Save