|
|
|
@ -20,8 +20,12 @@ namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP_GPU_KERNEL(
|
|
|
|
|
elementwise_mul,
|
|
|
|
|
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, float>,
|
|
|
|
|
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, double>);
|
|
|
|
|
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, double>,
|
|
|
|
|
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, int>,
|
|
|
|
|
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, int64_t>);
|
|
|
|
|
REGISTER_OP_GPU_KERNEL(
|
|
|
|
|
elementwise_mul_grad,
|
|
|
|
|
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, float>,
|
|
|
|
|
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, double>);
|
|
|
|
|
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, double>,
|
|
|
|
|
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, int>,
|
|
|
|
|
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, int64_t>);
|
|
|
|
|