|
|
|
@ -330,6 +330,7 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
|
|
|
namespace plat = paddle::platform;
|
|
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
|
|
|
|
REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
@ -356,16 +357,20 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel,
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel,
|
|
|
|
int64_t, ops::ReshapeKernel);
|
|
|
|
int64_t, ops::ReshapeKernel, plat::float16,
|
|
|
|
|
|
|
|
ops::ReshapeKernel);
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
|
|
|
ops::ReshapeGradKernel, plat::float16,
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel,
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel,
|
|
|
|
int64_t, ops::ReshapeKernel);
|
|
|
|
int64_t, ops::ReshapeKernel, plat::float16,
|
|
|
|
|
|
|
|
ops::ReshapeKernel);
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
|
|
|
ops::ReshapeGradKernel, plat::float16,
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
|