|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/pool_with_index_op.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -283,16 +284,33 @@ Example:
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class MaxPoolWithIndexGradOpMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::unique_ptr<T> Apply() const override {
|
|
|
|
|
std::unique_ptr<T> op(new T());
|
|
|
|
|
op->SetType(this->ForwardOpType() + "_grad");
|
|
|
|
|
op->SetAttrMap(this->Attrs());
|
|
|
|
|
op->SetInput("X", this->Input("X"));
|
|
|
|
|
op->SetInput("Mask", this->Output("Mask"));
|
|
|
|
|
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
|
|
|
|
|
return op;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(
|
|
|
|
|
max_pool2d_with_index, ops::MaxPoolWithIndexOp,
|
|
|
|
|
ops::MaxPool2dWithIndexOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
|
|
|
|
|
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
|
|
|
|
|
REGISTER_OPERATOR(max_pool2d_with_index, ops::MaxPoolWithIndexOp,
|
|
|
|
|
ops::MaxPool2dWithIndexOpMaker,
|
|
|
|
|
ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OPERATOR(max_pool2d_with_index_grad, ops::MaxPoolWithIndexOpGrad);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
@ -307,11 +325,10 @@ REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUDeviceContext, double,
|
|
|
|
|
int>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(
|
|
|
|
|
max_pool3d_with_index, ops::MaxPoolWithIndexOp,
|
|
|
|
|
ops::MaxPool3dWithIndexOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
|
|
|
|
|
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
|
|
|
|
|
REGISTER_OPERATOR(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
|
|
|
|
|
ops::MaxPool3dWithIndexOpMaker,
|
|
|
|
|
ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OPERATOR(max_pool3d_with_index_grad, ops::MaxPoolWithIndexOpGrad);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|