|
|
|
@ -621,15 +621,18 @@ REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
|
|
|
|
|
ops::ReshapeKernel, int8_t, ops::ReshapeKernel,
|
|
|
|
|
uint8_t, ops::ReshapeKernel, int,
|
|
|
|
|
ops::ReshapeKernel, int64_t, ops::ReshapeKernel);
|
|
|
|
|
ops::ReshapeKernel, int64_t, ops::ReshapeKernel,
|
|
|
|
|
bool, ops::ReshapeKernel);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
|
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel, bool,
|
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad_grad, float,
|
|
|
|
|
ops::ReshapeDoubleGradKernel, double,
|
|
|
|
|
ops::ReshapeDoubleGradKernel, int,
|
|
|
|
|
ops::ReshapeDoubleGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeDoubleGradKernel, bool,
|
|
|
|
|
ops::ReshapeDoubleGradKernel);
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
@ -641,15 +644,17 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
|
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel, plat::float16,
|
|
|
|
|
|
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
|
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel,
|
|
|
|
|
int64_t, ops::ReshapeKernel, plat::float16,
|
|
|
|
|
ops::ReshapeKernel);
|
|
|
|
|
ops::ReshapeKernel, bool, ops::ReshapeKernel);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
|
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel, plat::float16,
|
|
|
|
|
ops::ReshapeGradKernel, bool,
|
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad_grad, float,
|
|
|
|
@ -657,6 +662,7 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad_grad, float,
|
|
|
|
|
ops::ReshapeDoubleGradKernel, int,
|
|
|
|
|
ops::ReshapeDoubleGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeDoubleGradKernel, plat::float16,
|
|
|
|
|
ops::ReshapeDoubleGradKernel, bool,
|
|
|
|
|
ops::ReshapeDoubleGradKernel);
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
@ -664,10 +670,11 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad_grad, float,
|
|
|
|
|
REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
|
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel,
|
|
|
|
|
int64_t, ops::ReshapeKernel, plat::float16,
|
|
|
|
|
ops::ReshapeKernel);
|
|
|
|
|
ops::ReshapeKernel, bool, ops::ReshapeKernel);
|
|
|
|
|
REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
|
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel, plat::float16,
|
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
|
ops::ReshapeGradKernel,
|
|
|
|
|
bool ops::ReshapeGradKernel);
|
|
|
|
|
#endif
|
|
|
|
|