add int and int64 dtype for gather_op (#14175)

test=develop
fix_recordio_link
chengduo 6 years ago committed by GitHub
parent 62a0fe0860
commit b73708d20b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -102,7 +102,9 @@ REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(gather_grad, ops::GatherGradOp);
REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
ops::GatherOpKernel<int>, ops::GatherOpKernel<double>);
ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
ops::GatherOpKernel<int64_t>);
REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
ops::GatherGradientOpKernel<double>,
ops::GatherGradientOpKernel<int>,
ops::GatherGradientOpKernel<double>);
ops::GatherGradientOpKernel<int64_t>);

@ -61,5 +61,11 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>,
ops::GatherOpCUDAKernel<double>,
ops::GatherOpCUDAKernel<int64_t>,
ops::GatherOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>,
ops::GatherGradOpCUDAKernel<double>,
ops::GatherGradOpCUDAKernel<int64_t>,
ops::GatherGradOpCUDAKernel<int>);

Loading…
Cancel
Save