add double type kernel

release/0.11.0
chengduoZH 7 years ago
parent 0bc2f41da9
commit c359e39b59

@ -225,11 +225,15 @@ REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
ops::ConvOpGrad); ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL(conv2d, REGISTER_OP_CPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>); ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>); conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(conv3d, REGISTER_OP_CPU_KERNEL(conv3d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>); ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>); conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);

@ -17,11 +17,15 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv2d, REGISTER_OP_GPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::GPUPlace, float>); ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>); conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(conv3d, REGISTER_OP_GPU_KERNEL(conv3d,
ops::GemmConvKernel<paddle::platform::GPUPlace, float>); ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>); conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);

@ -185,17 +185,21 @@ REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker,
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d_transpose, conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>); ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d_transpose_grad, conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>); ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker, REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker,
conv3d_transpose_grad, ops::ConvTransposeOpGrad); conv3d_transpose_grad, ops::ConvTransposeOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv3d_transpose, conv3d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>); ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv3d_transpose_grad, conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>); ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);

@ -18,14 +18,18 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
conv2d_transpose, conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>); ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
conv2d_transpose_grad, conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>); ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
conv3d_transpose, conv3d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>); ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
conv3d_transpose_grad, conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>); ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);

Loading…
Cancel
Save