|
|
|
@ -142,10 +142,14 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|
|
|
|
}
|
|
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT(
|
|
|
|
|
cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_, stride_, dilation_, dilation_,
|
|
|
|
|
CUDNN_CROSS_CORRELATION, cudnn_data_type_),
|
|
|
|
|
CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
|
|
|
|
"cudnnSetConvolution2dDescriptor failed");
|
|
|
|
|
dx_desc_real = dx_desc_;
|
|
|
|
|
}
|
|
|
|
|
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
|
|
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH),
|
|
|
|
|
"cudnnSetConvolutionMathType failed.")
|
|
|
|
|
}
|
|
|
|
|
SelectAlgorithm(dx_desc_real);
|
|
|
|
|
InitSizeLists();
|
|
|
|
|
return true;
|
|
|
|
@ -239,7 +243,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|
|
|
|
"cudnnSetTensor4dDescriptor failed");
|
|
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT(
|
|
|
|
|
cudnnSetConvolution2dDescriptor(conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_, stride_,
|
|
|
|
|
dilation_, dilation_, CUDNN_CROSS_CORRELATION, cudnn_data_type_),
|
|
|
|
|
dilation_, dilation_, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
|
|
|
|
"cudnnSetConvolution2dDescriptor failed");
|
|
|
|
|
}
|
|
|
|
|
void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) {
|
|
|
|
@ -258,6 +262,9 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|
|
|
|
"cudnnGetConvolutionBackwardDataAlgorithm_v7 failed");
|
|
|
|
|
algo_ = perf_results.algo;
|
|
|
|
|
}
|
|
|
|
|
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
|
|
|
|
|
algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
void GetInputShape(const CNodePtr &kernel_node, std::vector<int> *input_shape) {
|
|
|
|
|
auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("input_sizes")->cast<ValueTuplePtr>()->value();
|
|
|
|
|