|
|
|
@ -88,19 +88,26 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Mask"), "Input(Mask) must not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
|
|
|
|
|
"Input(X@GRAD) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("Mask"), true,
|
|
|
|
|
platform::errors::NotFound("Input(Mask) must not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
|
platform::errors::NotFound("Input(X) must not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
|
platform::errors::NotFound("Input(Out@GRAD) should not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasOutput(framework::GradVarName("X")), true,
|
|
|
|
|
platform::errors::NotFound("Output(X@GRAD) should not be null."));
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
|
|
|
|
|
ctx.device_context());
|
|
|
|
|
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
|
|
|
|
|
ctx, framework::GradVarName("Out")),
|
|
|
|
|
ctx.device_context());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -302,6 +309,9 @@ class MaxPoolWithIndexGradOpMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
|
|
|
|
|
MaxPoolWithIndexOpGradNoNeedBufferVarsInference, "X");
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
@ -311,7 +321,8 @@ REGISTER_OPERATOR(max_pool2d_with_index, ops::MaxPoolWithIndexOp,
|
|
|
|
|
ops::MaxPool2dWithIndexOpMaker,
|
|
|
|
|
ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OPERATOR(max_pool2d_with_index_grad, ops::MaxPoolWithIndexOpGrad);
|
|
|
|
|
REGISTER_OPERATOR(max_pool2d_with_index_grad, ops::MaxPoolWithIndexOpGrad,
|
|
|
|
|
ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInference);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
max_pool2d_with_index,
|
|
|
|
@ -329,7 +340,8 @@ REGISTER_OPERATOR(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
|
|
|
|
|
ops::MaxPool3dWithIndexOpMaker,
|
|
|
|
|
ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OPERATOR(max_pool3d_with_index_grad, ops::MaxPoolWithIndexOpGrad);
|
|
|
|
|
REGISTER_OPERATOR(max_pool3d_with_index_grad, ops::MaxPoolWithIndexOpGrad,
|
|
|
|
|
ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInference);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
max_pool3d_with_index,
|
|
|
|
|