|
|
|
@ -249,6 +249,19 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class SqueezeDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
|
|
|
|
|
|
|
|
|
void Apply(GradOpPtr<T> grad_op) const override {
|
|
|
|
|
grad_op->SetType("squeeze");
|
|
|
|
|
grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
|
|
|
|
|
grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
|
|
|
|
|
grad_op->SetAttrMap(this->Attrs());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// FIXME(zcd): squeeze2 adds an intermediate output(XShape) based on squeeze,
|
|
|
|
|
// the XShape is used to carry the shape and lod of X which will be used in
|
|
|
|
|
// squeeze_grad, in this way, the framework can reuse the memory of X
|
|
|
|
@ -279,8 +292,22 @@ class Squeeze2GradOpMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DECLARE_INPLACE_OP_INFERER(SequeezeInplaceInferer, {"X", "Out"});
|
|
|
|
|
DECLARE_INPLACE_OP_INFERER(SequeezeGradInplaceInferer,
|
|
|
|
|
template <typename T>
|
|
|
|
|
class Squeeze2DoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
|
|
|
|
|
|
|
|
|
void Apply(GradOpPtr<T> grad_op) const override {
|
|
|
|
|
grad_op->SetType("squeeze2");
|
|
|
|
|
grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
|
|
|
|
|
grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
|
|
|
|
|
grad_op->SetOutput("XShape", this->Input("XShape"));
|
|
|
|
|
grad_op->SetAttrMap(this->Attrs());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DECLARE_INPLACE_OP_INFERER(SqueezeInplaceInferer, {"X", "Out"});
|
|
|
|
|
DECLARE_INPLACE_OP_INFERER(SqueezeGradInplaceInferer,
|
|
|
|
|
{framework::GradVarName("Out"),
|
|
|
|
|
framework::GradVarName("X")});
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERER(SqueezeGradNoNeedBufferVarsInferer, "X");
|
|
|
|
@ -292,14 +319,18 @@ REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
|
|
|
|
|
ops::SqueezeGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::SqueezeGradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp,
|
|
|
|
|
ops::SqueezeDoubleGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::SqueezeDoubleGradOpMaker<paddle::imperative::OpBase>,
|
|
|
|
|
ops::SqueezeGradNoNeedBufferVarsInferer);
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(squeeze2, ops::Squeeze2Op, ops::Squeeze2OpMaker,
|
|
|
|
|
ops::Squeeze2GradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::Squeeze2GradOpMaker<paddle::imperative::OpBase>,
|
|
|
|
|
ops::SequeezeInplaceInferer);
|
|
|
|
|
ops::SqueezeInplaceInferer);
|
|
|
|
|
REGISTER_OPERATOR(squeeze2_grad, ops::Squeeze2GradOp,
|
|
|
|
|
ops::SequeezeGradInplaceInferer);
|
|
|
|
|
ops::Squeeze2DoubleGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::Squeeze2DoubleGradOpMaker<paddle::imperative::OpBase>,
|
|
|
|
|
ops::SqueezeGradInplaceInferer);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
squeeze, ops::SqueezeKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|