|
|
|
@ -40,7 +40,7 @@ int PoolOutputSize(int input_size, int filter_size, int padding, int stride,
|
|
|
|
|
return output_size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
|
|
|
|
|
void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Out(Output) of Pooling should not be null.");
|
|
|
|
@ -81,7 +81,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType PoolOp::GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const {
|
|
|
|
|
const framework::ExecutionContext& ctx) const {
|
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
@ -104,7 +104,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
|
|
|
|
|
layout_, library_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
|
|
|
|
|
void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
|
|
|
|
|
"Input(X@GRAD) should not be null.");
|
|
|
|
@ -112,7 +112,7 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const {
|
|
|
|
|
const framework::ExecutionContext& ctx) const {
|
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
@ -262,6 +262,14 @@ Example:
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class PoolOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
|
|
|
|
|
protected:
|
|
|
|
|
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
|
|
|
|
|
const override {
|
|
|
|
|
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void Pool3dOpMaker::Make() {
|
|
|
|
|
AddInput("X",
|
|
|
|
|
"(Tensor) The input tensor of pooling operator. "
|
|
|
|
@ -372,6 +380,7 @@ Example:
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(pool2d, ops::PoolOp, ops::Pool2dOpMaker,
|
|
|
|
|
ops::PoolOpInferVarType,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
REGISTER_OPERATOR(pool2d_grad, ops::PoolOpGrad);
|
|
|
|
|
|
|
|
|
@ -383,6 +392,7 @@ REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
ops::PoolGradKernel<paddle::platform::CPUDeviceContext, double>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(pool3d, ops::PoolOp, ops::Pool3dOpMaker,
|
|
|
|
|
ops::PoolOpInferVarType,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
REGISTER_OPERATOR(pool3d_grad, ops::PoolOpGrad);
|
|
|
|
|
|
|
|
|
|