@ -618,25 +618,25 @@ REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
ops : : ReshapeDoubleGradInplaceInferer ,
ops : : ReshapeDoubleGradOpNoNeedBufferVarInferer ) ;
REGISTER_OP_CPU_KERNEL_FUNCTOR ( reshape2 , float , ops : : ReshapeKernel , double ,
ops : : ReshapeKernel , int8_t , ops : : ReshapeKernel ,
uint8_t , ops : : ReshapeKernel , int ,
ops : : ReshapeKernel , int64_t , ops : : ReshapeKernel ,
bool , ops : : ReshapeKernel ,
paddle : : platform : : bfloat16 , ops : : ReshapeKernel ) ;
REGISTER_OP_CPU_KERNEL_FUNCTOR ( reshape2_grad , float , ops : : ReshapeGradKernel ,
double , ops : : ReshapeGradKernel , int ,
ops : : ReshapeGradKernel , uint8_t ,
ops : : ReshapeGradKernel , int64_t ,
ops : : ReshapeGradKernel , boo l,
ops : : ReshapeGradKernel ) ;
REGISTER_OP_CPU_KERNEL_FUNCTOR ( reshape2_grad_grad , float ,
ops : : ReshapeDoubleGradKernel , double ,
ops : : ReshapeDoubleGradKernel , int,
ops : : ReshapeDoubleGradKernel , uint8_t ,
ops : : ReshapeDoubleGradKernel , int64_t ,
ops : : ReshapeDoubleGradKernel , bool ,
REGISTER_OP_CPU_KERNEL_FUNCTOR (
reshape2 , float , ops : : ReshapeKernel , double , ops : : ReshapeKernel , int8_t ,
ops : : ReshapeKernel , uint8_t , ops : : ReshapeKernel , int , ops : : ReshapeKernel ,
int64_t , ops : : ReshapeKernel , bool , ops : : ReshapeKernel ,
paddle : : platform : : bfloat16 , ops : : ReshapeKernel , paddle : : platform : : complex64 ,
ops : : ReshapeKernel , paddle : : platform : : complex128 , ops : : ReshapeKernel ) ;
REGISTER_OP_CPU_KERNEL_FUNCTOR (
reshape2_grad , float , ops : : ReshapeGradKernel , double ,
ops : : ReshapeGradKernel , int , ops : : ReshapeGradKernel , uint8_t ,
ops : : ReshapeGradKernel , int64_t , ops : : ReshapeGradKernel , bool ,
ops : : ReshapeGradKernel , paddle : : platform : : complex64 , ops : : ReshapeGradKerne l,
paddle : : platform : : complex128 , ops : : ReshapeGradKernel ) ;
REGISTER_OP_CPU_KERNEL_FUNCTOR (
reshape2_grad_grad , float , ops : : ReshapeDoubleGradKernel , double ,
ops : : ReshapeDoubleGradKernel , int , ops : : ReshapeDoubleGradKernel , u int8_ t,
ops : : ReshapeDoubleGradKernel , int64_t, ops : : ReshapeDoubleGradKernel , bool ,
ops : : ReshapeDoubleGradKernel , paddle : : platform : : complex64 ,
ops : : ReshapeDoubleGradKernel , paddle : : platform : : complex128 ,
ops : : ReshapeDoubleGradKernel ) ;
# ifdef PADDLE_WITH_CUDA
@ -656,34 +656,38 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops : : ReshapeKernel , int , ops : : ReshapeKernel ,
uint8_t , ops : : ReshapeKernel , int64_t ,
ops : : ReshapeKernel , plat : : float16 ,
ops : : ReshapeKernel , bool , ops : : ReshapeKernel );
REGISTER_OP_CUDA_KERNEL_FUNCTOR ( reshape2_grad , float , ops : : Reshape Grad Kernel,
double , ops : : ReshapeGradKernel , int ,
ops : : ReshapeGradKernel , uint8_t ,
ops : : ReshapeGradKernel , int64_t ,
ops : : ReshapeGradKernel , plat : : float16 ,
ops : : ReshapeGradKernel , bool ,
ops : : ReshapeGradKernel ) ;
REGISTER_OP_CUDA_KERNEL_FUNCTOR ( reshape2_grad_grad , float ,
ops : : ReshapeDoubleGradKernel , double ,
ops : : ReshapeDoubleGradKernel , int ,
ops : : ReshapeDoubleGradKernel , uint8_t ,
ops : : ReshapeDoubleGradKernel , int64_t ,
ops : : ReshapeDoubleGradKernel , plat : : float16 ,
ops : : ReshapeDoubleGradKernel , boo l,
ops : : ReshapeDoubleGradKernel ) ;
ops : : ReshapeKernel , bool , ops : : ReshapeKernel ,
plat : : complex64 , ops : : Reshape Kernel,
plat : : complex128 , ops : : ReshapeKernel ) ;
REGISTER_OP_CUDA_KERNEL_FUNCTOR (
reshape2_grad , float , ops : : ReshapeGradKernel , double ,
ops : : ReshapeGradKernel , int , ops : : ReshapeGradKernel , uint8_t ,
ops : : ReshapeGradKernel , int64_t, ops : : ReshapeGradKernel , plat : : float16 ,
ops : : ReshapeGradKernel , bool , ops : : ReshapeGradKernel , plat : : complex64 ,
ops : : ReshapeGradKernel , plat : : complex128 , ops : : ReshapeGradKernel ) ;
REGISTER_OP_CUDA_KERNEL_FUNCTOR (
reshape2_grad_grad , float , ops : : ReshapeDoubleGradKernel , double ,
ops : : ReshapeDoubleGradKernel , int , ops : : ReshapeDoubleGradKernel , uint8_t ,
ops : : ReshapeDoubleGradKernel , int64_t , ops : : ReshapeDoubleGradKernel ,
plat : : float16 , ops : : ReshapeDoubleGradKernel , bool ,
ops : : ReshapeDoubleGradKernel , plat : : complex64 , ops : : ReshapeDoubleGradKerne l,
plat : : complex128 , ops : : ReshapeDoubleGradKernel ) ;
# endif
# ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL_FUNCTOR ( reshape2 , float , ops : : ReshapeKernel , double ,
ops : : ReshapeKernel , int , ops : : ReshapeKernel ,
int64_t , ops : : ReshapeKernel , plat : : float16 ,
ops : : ReshapeKernel , bool , ops : : ReshapeKernel ) ;
ops : : ReshapeKernel , bool , ops : : ReshapeKernel ,
plat : : complex64 , ops : : ReshapeKernel ,
plat : : complex128 , ops : : ReshapeKernel ) ;
REGISTER_OP_XPU_KERNEL_FUNCTOR ( reshape2_grad , float , ops : : ReshapeGradKernel ,
double , ops : : ReshapeGradKernel , int ,
ops : : ReshapeGradKernel , int64_t ,
ops : : ReshapeGradKernel , plat : : float16 ,
ops : : ReshapeGradKernel , bool ,
ops : : ReshapeGradKernel , plat : : complex64 ,
ops : : ReshapeGradKernel , plat : : complex128 ,
ops : : ReshapeGradKernel ) ;
# endif