@@ -54,15 +54,21 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
     ScopedTensorDescriptor output_desc;
     ScopedFilterDescriptor filter_desc;
     ScopedConvolutionDescriptor conv_desc;
-    DataLayout layout = DataLayout::kNCHW;
+    DataLayout layout;
+
+    if (strides.size() == 2U) {
+      layout = DataLayout::kNCHW;
+    } else {
+      layout = DataLayout::kNCDHW;
+    }
 
-    // N, M, H, W
+    // (N, M, H, W) or (N, M, D, H, W)
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
         layout, framework::vectorize2int(input->dims()));
-    // N, C, O_h, O_w
+    // (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
     cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
         layout, framework::vectorize2int(output->dims()));
-    // M, C, K_h, K_w
+    // (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
     cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
         layout, framework::vectorize2int(filter->dims()));
     cudnnConvolutionDescriptor_t cudnn_conv_desc =
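
Note: the only functional change in the forward kernel above is that layout is now derived from the stride rank instead of being hard-coded: two strides (H, W) mean a 4-D NCHW input, three (D, H, W) mean 5-D NCDHW. A minimal standalone sketch of that dispatch, using a stand-in enum and a hypothetical LayoutFromStrides helper rather than Paddle's real types:

    #include <cassert>
    #include <vector>

    enum class DataLayout { kNCHW, kNCDHW };  // stand-in for Paddle's DataLayout

    // Hypothetical helper mirroring the dispatch added above: a 2-D transposed
    // conv carries two strides (H, W); a 3-D one carries three (D, H, W).
    DataLayout LayoutFromStrides(const std::vector<int>& strides) {
      return strides.size() == 2U ? DataLayout::kNCHW : DataLayout::kNCDHW;
    }

    int main() {
      assert(LayoutFromStrides({1, 1}) == DataLayout::kNCHW);      // conv2d_transpose
      assert(LayoutFromStrides({2, 2, 2}) == DataLayout::kNCDHW);  // conv3d_transpose
      return 0;
    }
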
@@ -136,13 +142,13 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
     ScopedConvolutionDescriptor conv_desc;
     DataLayout layout = DataLayout::kNCHW;
 
-    // Input: (N, M, H, W)
+    // Input: (N, M, H, W) or (N, M, D, H, W)
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
         layout, framework::vectorize2int(input->dims()));
-    // Output: (N, C, O_H, O_W)
+    // Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
     cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
         layout, framework::vectorize2int(output_grad->dims()));
-    // Filter (M, C, K_H, K_W)
+    // Filter (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
     cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
         layout, framework::vectorize2int(filter->dims()));
@@ -200,8 +206,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
     T alpha = 1.0f, beta = 0.0f;
     if (input_grad) {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
-      math::set_constant(ctx.device_context(), input_grad, 0);
-
+      // Because beta is zero, it is unnecessary to reset input_grad.
       PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
           handle, &alpha, cudnn_output_desc, output_grad_data,
           cudnn_filter_desc, filter_data, cudnn_conv_desc, data_algo,
@@ -212,8 +217,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
     // ------------------- cudnn conv backward filter ---------------------
     if (filter_grad) {
       T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
-      math::set_constant(ctx.device_context(), filter_grad, 0);
-
+      // Because beta is zero, it is unnecessary to reset filter_grad.
       // Gradient with respect to the filter
       PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
           handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc,
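
Note: both set_constant deletions above rely on the same cuDNN contract: the convolution routines compute dst = alpha * result + beta * dst, and when beta is zero the prior contents of dst are never read, so pre-zeroing input_grad and filter_grad was pure overhead. A plain-C++ sketch of that blend rule (the Blend helper is illustrative, not a cuDNN API):

    #include <cassert>

    // Illustrative model of cuDNN's output blending: dst = alpha*result + beta*dst.
    // When beta == 0 cuDNN skips reading dst entirely, which is why the deleted
    // set_constant zero-fills changed nothing.
    float Blend(float alpha, float result, float beta, float dst_prior) {
      return beta == 0.0f ? alpha * result  // dst never read
                          : alpha * result + beta * dst_prior;
    }

    int main() {
      float stale = 42.0f;  // whatever the buffer happened to hold
      assert(Blend(1.0f, 3.5f, 0.0f, stale) == 3.5f);   // beta = 0: stale ignored
      assert(Blend(1.0f, 3.5f, 1.0f, stale) == 45.5f);  // beta = 1: accumulate
      return 0;
    }
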
@@ -234,3 +238,8 @@ REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
                        ops::CudnnConvTransposeOpKernel<float>);
 REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
                        ops::CudnnConvTransposeGradOpKernel<float>);
+
+REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn,
+                       ops::CudnnConvTransposeOpKernel<float>);
+REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad,
+                       ops::CudnnConvTransposeGradOpKernel<float>);
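
Note: the new conv3d_transpose_cudnn registrations reuse the same templated kernels as the 2-D ops, because the transposed-convolution output size is computed independently per spatial axis; 2-D and 3-D differ only in how many axes the rule is applied to. A sketch of that per-axis rule (my own helper, ignoring dilation and output padding; not code from this file):

    #include <cassert>

    // Output extent of a transposed convolution along one spatial axis,
    // the inverse of the forward conv's O = (I + 2*pad - K) / stride + 1.
    // Applied once per axis: twice for H/W, three times for D/H/W.
    int TransposedOutSize(int in, int kernel, int stride, int pad) {
      return (in - 1) * stride - 2 * pad + kernel;
    }

    int main() {
      assert(TransposedOutSize(4, 3, 1, 0) == 6);  // stride 1, no padding
      assert(TransposedOutSize(5, 3, 2, 1) == 9);  // stride 2, pad 1
      return 0;
    }
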