deconv2d kernel and deconv3d kernel write together

mobile_baidu
chengduoZH 8 years ago
parent 0f1b30ef86
commit 206f32c13a

@ -44,7 +44,7 @@ REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp,
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn,
ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn_grad,
ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);

@ -187,17 +187,17 @@ REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker,
REGISTER_OP_CPU_KERNEL(
conv2d_transpose,
ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_grad,
ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker,
conv3d_transpose_grad, ops::ConvTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose,
ops::GemmConv3DTransposeKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_grad,
ops::GemmConv3DTransposeGradKernel<paddle::platform::CPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);

@ -18,14 +18,14 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
conv2d_transpose,
ops::GemmConv2DTransposeKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
conv2d_transpose_grad,
ops::GemmConv2DTransposeGradKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
conv3d_transpose,
ops::GemmConv3DTransposeKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
conv3d_transpose_grad,
ops::GemmConv3DTransposeGradKernel<paddle::platform::GPUPlace, float>);
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save