|
|
|
@ -22,14 +22,8 @@ int OutputSizePool(int input_size, int filter_size, int padding, int stride) {
|
|
|
|
|
return output_size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class PoolOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"X(Input) of Pooling should not be null.");
|
|
|
|
|
void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Out(Output) of Pooling should not be null.");
|
|
|
|
|
|
|
|
|
@ -63,24 +57,16 @@ class PoolOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class PoolOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
|
|
|
|
|
"Input(X@GRAD) should not be null.");
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
Pool2dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
|
|
|
|
|
Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
|
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput(
|
|
|
|
|
"X",
|
|
|
|
@ -132,11 +118,9 @@ Parameters(ksize, strides, paddings) are two elements.
|
|
|
|
|
These two elements represent height and width, respectively.
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
Pool3dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
|
|
|
|
|
Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
|
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput(
|
|
|
|
|
"X",
|
|
|
|
@ -169,8 +153,7 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"Default false."
|
|
|
|
|
"If globalPooling = true, ksize is ignored and need not be specified.")
|
|
|
|
|
.SetDefault(false);
|
|
|
|
|
AddAttr<std::vector<int>>(
|
|
|
|
|
"strides",
|
|
|
|
|
AddAttr<std::vector<int>>("strides",
|
|
|
|
|
"Strides(depth, height, width) of pooling operator."
|
|
|
|
|
"Default {1,1,1}.")
|
|
|
|
|
.SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently,
|
|
|
|
@ -191,7 +174,6 @@ width of feature. Parameters(ksize, strides, paddings) are three elements.
|
|
|
|
|
These three elements represent depth, height and width, respectively.
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|