|
|
|
@ -102,7 +102,9 @@ REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
REGISTER_OPERATOR(gather_grad, ops::GatherGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
|
|
|
|
|
ops::GatherOpKernel<int>, ops::GatherOpKernel<double>);
|
|
|
|
|
ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
|
|
|
|
|
ops::GatherOpKernel<int64_t>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
|
|
|
|
|
ops::GatherGradientOpKernel<double>,
|
|
|
|
|
ops::GatherGradientOpKernel<int>,
|
|
|
|
|
ops::GatherGradientOpKernel<double>);
|
|
|
|
|
ops::GatherGradientOpKernel<int64_t>);
|
|
|
|
|