|
|
|
@ -46,8 +46,6 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|
|
|
|
pad_left_(0),
|
|
|
|
|
n_(0),
|
|
|
|
|
c_(0),
|
|
|
|
|
stride_(1),
|
|
|
|
|
dilation_(0),
|
|
|
|
|
group_(1),
|
|
|
|
|
is_null_input_(false),
|
|
|
|
|
input_size_(0),
|
|
|
|
@ -84,7 +82,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|
|
|
|
cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
|
|
|
|
|
workspace_size_, &beta, padded_descriptor_, padded),
|
|
|
|
|
"ConvolutionBackwardData failed");
|
|
|
|
|
CalPadGrad(padded_size_ / sizeof(T), padded, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
|
|
|
|
|
CalPadGrad(input_size_ / sizeof(T), padded, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
|
|
|
|
|
old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
|
|
|
|
|
} else {
|
|
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT(
|
|
|
|
@ -129,8 +127,8 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|
|
|
|
pad_width_ = 0;
|
|
|
|
|
}
|
|
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT(
|
|
|
|
|
cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_, stride_, dilation_, dilation_,
|
|
|
|
|
CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
|
|
|
|
cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[0], stride_[1], dilation_[2],
|
|
|
|
|
dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
|
|
|
|
"cudnnSetConvolution2dDescriptor failed");
|
|
|
|
|
dx_desc_real = dx_desc_;
|
|
|
|
|
}
|
|
|
|
@ -229,10 +227,10 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_,
|
|
|
|
|
c_, old_height_ + pad_height_, old_width_ + pad_width_),
|
|
|
|
|
"cudnnSetTensor4dDescriptor failed");
|
|
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT(
|
|
|
|
|
cudnnSetConvolution2dDescriptor(conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_, stride_,
|
|
|
|
|
dilation_, dilation_, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
|
|
|
|
"cudnnSetConvolution2dDescriptor failed");
|
|
|
|
|
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor(
|
|
|
|
|
conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[0], stride_[1],
|
|
|
|
|
dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
|
|
|
|
"cudnnSetConvolution2dDescriptor failed");
|
|
|
|
|
}
|
|
|
|
|
void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) {
|
|
|
|
|
if (group_ > 1 || CUDNN_MAJOR < 7) {
|
|
|
|
@ -275,19 +273,17 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|
|
|
|
"SetTensor4dDescriptor failed");
|
|
|
|
|
}
|
|
|
|
|
void SetStrideAndDilation(const CNodePtr &kernel_node) {
|
|
|
|
|
auto stride_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
|
|
|
|
|
auto dilation_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
|
|
|
|
|
if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel only support equal stride, and stride must be 2d!";
|
|
|
|
|
stride_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
|
|
|
|
|
dilation_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
|
|
|
|
|
if (stride_.size() != 2) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel's stride must be 2d!";
|
|
|
|
|
}
|
|
|
|
|
if (dilation_ori.size() != 4 || dilation_ori[2] != dilation_ori[3]) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel only support equal dilation, and dilation must be 4d!";
|
|
|
|
|
if (dilation_.size() != 4) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel's dilation must be 4d!";
|
|
|
|
|
}
|
|
|
|
|
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
|
|
|
|
|
if (dilation_[0] != 1 || dilation_[1] != 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel dilation only support 1 in N axis and C axis!";
|
|
|
|
|
}
|
|
|
|
|
stride_ = stride_ori[0];
|
|
|
|
|
dilation_ = dilation_ori[2];
|
|
|
|
|
}
|
|
|
|
|
cudnnHandle_t cudnn_handle_;
|
|
|
|
|
cudnnFilterDescriptor_t w_desc_;
|
|
|
|
@ -309,8 +305,8 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|
|
|
|
int pad_left_;
|
|
|
|
|
int n_;
|
|
|
|
|
int c_;
|
|
|
|
|
int stride_;
|
|
|
|
|
int dilation_;
|
|
|
|
|
std::vector<int> stride_;
|
|
|
|
|
std::vector<int> dilation_;
|
|
|
|
|
int group_;
|
|
|
|
|
bool is_null_input_;
|
|
|
|
|
size_t input_size_;
|
|
|
|
|