modify for code review by wangyi

release/0.11.0
sweetsky0901 7 years ago
parent e553d5728d
commit 66b84366f1

@ -16,11 +16,9 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::Tensor;
class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Unpool2dOpMaker(framework::OpProto* proto, \ Unpool2dOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
@ -38,26 +36,26 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"the number of channels, H and W is the height and " "the number of channels, H and W is the height and "
"width of feature."); "width of feature.");
AddAttr<std::vector<int>>("ksize", AddAttr<std::vector<int>>("ksize",
"(vector ), the unpooling window size(height, width) " "(vector), the unpooling window size(height, width) "
"of unpooling operator."); "of unpooling operator.");
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector, default:{1, 1}), " "(vector, default:{1, 1}), "
"strides(height, width) of unpooling operator.") "strides (height, width) of unpooling operator.")
.SetDefault({1, 1}); .SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings", AddAttr<std::vector<int>>("paddings",
"(vector defalut:{0,0}), " "(vector defalut:{0,0}), "
"paddings(height, width) of unpooling operator.") "paddings (height, width) of unpooling operator.")
.SetDefault({0, 0}); .SetDefault({0, 0});
AddAttr<std::string>("unpoolingtype", AddAttr<std::string>("unpoolingtype",
"(string), unpooling type, can be \"max\" for max-unpooling ") "(string), unpooling type, can be \"max\" for max-unpooling ")
.InEnum({"max"}); .InEnum({"max"});
AddComment(R"DOC( AddComment(R"DOC(
"input: the input Tensor to invert" "input: the input Tensor to invert
"indices: the indices given out by MaxPool2d" indices: the indices given out by MaxPool2d
"ksize Size of the max pooling window." ksize Size of the max pooling window.
"stride Stride of the max pooling window." stride Stride of the max pooling window.
"It is set to kernel_size by default." "It is set to kernel_size by default.
"padding Padding that was added to the input" padding Padding that was added to the input"
)DOC"); )DOC");
} }
}; };
@ -80,14 +78,14 @@ class UnpoolOp : public framework::OperatorWithKernel {
auto in_x_dims = ctx->GetInputDim("X"); auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Y"); auto in_y_dims = ctx->GetInputDim("Y");
std::string unpoolingtype = \ std::string unpoolingtype =
ctx->Attrs().Get<std::string>("unpoolingtype"); ctx->Attrs().Get<std::string>("unpoolingtype");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize"); std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides"); std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(in_x_dims.size() == 4, PADDLE_ENFORCE(in_x_dims.size() == 4,
"Unpooling intput should be 4-D."); "Unpooling intput must be of 4-dimensional.");
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
PADDLE_ENFORCE(in_x_dims[i] == in_y_dims[i], PADDLE_ENFORCE(in_x_dims[i] == in_y_dims[i],
"X size must be eq Y size!"); "X size must be eq Y size!");

@ -21,15 +21,13 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T> template <typename Place, typename T>
class UnpoolKernel : public framework::OpKernel<T> { class UnpoolKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X"); const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const Tensor* in_y = context.Input<Tensor>("Y"); const framework::Tensor* in_y = context.Input<framework::Tensor>("Y");
auto * out = context.Output<Tensor>("Out"); auto * out = context.Output<framework::Tensor>("Out");
std::string unpoolingtype = context.Attr<std::string>("unpoolingtype"); std::string unpoolingtype = context.Attr<std::string>("unpoolingtype");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
@ -39,15 +37,8 @@ class UnpoolKernel : public framework::OpKernel<T> {
math::SetConstant<Place, T> set_zero; math::SetConstant<Place, T> set_zero;
set_zero(context.device_context(), out, static_cast<T>(0)); set_zero(context.device_context(), out, static_cast<T>(0));
} }
switch (ksize.size()) { math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
case 2: { unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
if (unpoolingtype == "max") {
math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
}
} break;
default: { PADDLE_THROW("Pool op only supports 2D input."); }
}
} }
}; };
@ -55,12 +46,13 @@ template <typename Place, typename T>
class UnpoolGradKernel : public framework::OpKernel<T> { class UnpoolGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X"); const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const Tensor* in_y = context.Input<Tensor>("Y"); const framework::Tensor* in_y = context.Input<framework::Tensor>("Y");
const Tensor* out = context.Input<Tensor>("Out"); const framework::Tensor* out = context.Input<framework::Tensor>("Out");
const Tensor* out_grad = const framework::Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Out")); context.Input<framework::Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X")); framework::Tensor* in_x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
std::string unpoolingtype = context.Attr<std::string>("unpoolingtype"); std::string unpoolingtype = context.Attr<std::string>("unpoolingtype");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
@ -70,18 +62,11 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
math::SetConstant<Place, T> zero; math::SetConstant<Place, T> zero;
if (in_x_grad) { if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace()); in_x_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, in_x_grad, static_cast<T>(0.0)); zero(device_ctx, in_x_grad, static_cast<T>(0));
}
switch (ksize.size()) {
case 2: {
if (unpoolingtype == "max") {
math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad,
*out, *out_grad);
}
} break;
default: { PADDLE_THROW("Unpool op only supports 2D input."); }
} }
math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad,
*out, *out_grad);
} }
}; };

Loading…
Cancel
Save