|
|
|
@ -545,12 +545,12 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInToOut, {"X", "Out"});
|
|
|
|
|
DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInToOut,
|
|
|
|
|
DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"});
|
|
|
|
|
DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInferer,
|
|
|
|
|
{framework::GradVarName("Out"),
|
|
|
|
|
framework::GradVarName("X")});
|
|
|
|
|
DECLARE_INPLACE_OP_INFERER(ReshapeDoubleGradInplaceInToOut, {"DDX", "DDOut"});
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReshapeDoubleGradOpNoNeedBufferVarInference,
|
|
|
|
|
DECLARE_INPLACE_OP_INFERER(ReshapeDoubleGradInplaceInferer, {"DDX", "DDOut"});
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReshapeDoubleGradOpNoNeedBufferVarInferer,
|
|
|
|
|
"DOut");
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
@ -562,9 +562,9 @@ REGISTER_OPERATOR(
|
|
|
|
|
reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
|
|
|
|
|
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
|
|
|
|
|
ops::ReshapeOpInplaceInToOut);
|
|
|
|
|
ops::ReshapeOpInplaceInferer);
|
|
|
|
|
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp,
|
|
|
|
|
ops::ReshapeGradInplaceInToOut);
|
|
|
|
|
ops::ReshapeGradInplaceInferer);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
|
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel,
|
|
|
|
@ -576,14 +576,14 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
|
|
|
|
|
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
|
|
|
|
|
ops::Reshape2GradMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::Reshape2GradMaker<paddle::imperative::OpBase>,
|
|
|
|
|
ops::ReshapeOpInplaceInToOut);
|
|
|
|
|
ops::ReshapeOpInplaceInferer);
|
|
|
|
|
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
|
|
|
|
|
ops::Reshape2DoubleGradMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>,
|
|
|
|
|
ops::ReshapeGradInplaceInToOut);
|
|
|
|
|
ops::ReshapeGradInplaceInferer);
|
|
|
|
|
REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
|
|
|
|
|
ops::ReshapeDoubleGradInplaceInToOut,
|
|
|
|
|
ops::ReshapeDoubleGradOpNoNeedBufferVarInference);
|
|
|
|
|
ops::ReshapeDoubleGradInplaceInferer,
|
|
|
|
|
ops::ReshapeDoubleGradOpNoNeedBufferVarInferer);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
|
|
|
|
|
ops::ReshapeKernel, int8_t, ops::ReshapeKernel,
|
|
|
|
|