|
|
|
@ -61,10 +61,12 @@ REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp,
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
conv2d_transpose_cudnn,
|
|
|
|
|
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
|
|
|
|
|
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
conv2d_transpose_cudnn_grad,
|
|
|
|
|
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
|
|
|
|
|
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
|
|
|
|
|
ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad,
|
|
|
|
@ -72,7 +74,9 @@ REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
conv3d_transpose_cudnn,
|
|
|
|
|
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
|
|
|
|
|
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
conv3d_transpose_cudnn_grad,
|
|
|
|
|
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
|
|
|
|
|
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
|
|
|
|
|