|
|
@ -64,13 +64,13 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
|
|
// (N, M, H, W) or (N, M, D, H, W)
|
|
|
|
// (N, M, H, W) or (N, M, D, H, W)
|
|
|
|
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
|
|
|
|
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
|
|
|
|
layout, framework::vectorize2int(input->dims()), groups);
|
|
|
|
layout, framework::vectorize<int>(input->dims()), groups);
|
|
|
|
// (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
|
|
|
|
// (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
|
|
|
|
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
|
|
|
|
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
|
|
|
|
layout, framework::vectorize2int(output->dims()), groups);
|
|
|
|
layout, framework::vectorize<int>(output->dims()), groups);
|
|
|
|
// (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
|
|
|
|
// (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
|
|
|
|
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
|
|
|
|
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
|
|
|
|
layout, framework::vectorize2int(filter->dims()), groups);
|
|
|
|
layout, framework::vectorize<int>(filter->dims()), groups);
|
|
|
|
cudnnConvolutionDescriptor_t cudnn_conv_desc =
|
|
|
|
cudnnConvolutionDescriptor_t cudnn_conv_desc =
|
|
|
|
conv_desc.descriptor<T>(paddings, strides, dilations);
|
|
|
|
conv_desc.descriptor<T>(paddings, strides, dilations);
|
|
|
|
|
|
|
|
|
|
|
@ -148,13 +148,13 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
|
|
// Input: (N, M, H, W) or (N, M, D, H, W)
|
|
|
|
// Input: (N, M, H, W) or (N, M, D, H, W)
|
|
|
|
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
|
|
|
|
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
|
|
|
|
layout, framework::vectorize2int(input->dims()), groups);
|
|
|
|
layout, framework::vectorize<int>(input->dims()), groups);
|
|
|
|
// Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
|
|
|
|
// Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
|
|
|
|
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
|
|
|
|
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
|
|
|
|
layout, framework::vectorize2int(output_grad->dims()), groups);
|
|
|
|
layout, framework::vectorize<int>(output_grad->dims()), groups);
|
|
|
|
// Filter (M, C, K_h, K_w) or (M, C, K_d K_h, K_w)
|
|
|
|
// Filter (M, C, K_h, K_w) or (M, C, K_d K_h, K_w)
|
|
|
|
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
|
|
|
|
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
|
|
|
|
layout, framework::vectorize2int(filter->dims()), groups);
|
|
|
|
layout, framework::vectorize<int>(filter->dims()), groups);
|
|
|
|
|
|
|
|
|
|
|
|
cudnnConvolutionDescriptor_t cudnn_conv_desc =
|
|
|
|
cudnnConvolutionDescriptor_t cudnn_conv_desc =
|
|
|
|
conv_desc.descriptor<T>(paddings, strides, dilations);
|
|
|
|
conv_desc.descriptor<T>(paddings, strides, dilations);
|
|
|
|