|
|
|
@ -114,23 +114,7 @@ class Conv2dGpuFwdKernel : public GpuKernel {
|
|
|
|
|
pad_height_ = GetAttr<int>(kernel_node, "pad");
|
|
|
|
|
pad_width_ = pad_height_;
|
|
|
|
|
pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
|
|
|
|
|
auto stride_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
|
|
|
|
|
auto dilation_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
|
|
|
|
|
if (stride_ori.size() != 4 || stride_ori[2] != stride_ori[3]) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "conv2d only support equal stride, and stride must be 4d!";
|
|
|
|
|
}
|
|
|
|
|
if (stride_ori[0] != 1 || stride_ori[1] != 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "conv2d stride only support 1 in N axis and C axis!";
|
|
|
|
|
}
|
|
|
|
|
if (dilation_ori.size() != 4 || dilation_ori[2] != dilation_ori[3]) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "conv2d only support equal dilation, and dilation must be 4d!";
|
|
|
|
|
}
|
|
|
|
|
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "conv2d dilation only support 1 in N axis and C axis!";
|
|
|
|
|
}
|
|
|
|
|
stride_ = stride_ori[2];
|
|
|
|
|
dilation_ = dilation_ori[2];
|
|
|
|
|
|
|
|
|
|
SetStrideAndDilation(kernel_node);
|
|
|
|
|
cudnnTensorDescriptor_t input_descriptor_real = nullptr;
|
|
|
|
|
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
|
|
|
|
|
SetPad(in_shape, kernel_node);
|
|
|
|
@ -277,6 +261,24 @@ class Conv2dGpuFwdKernel : public GpuKernel {
|
|
|
|
|
conv_algorithm_ = perf_results.algo;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
void SetStrideAndDilation(const CNodePtr &kernel_node) {
|
|
|
|
|
auto stride_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
|
|
|
|
|
auto dilation_ori = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
|
|
|
|
|
if (stride_ori.size() != 4 || stride_ori[2] != stride_ori[3]) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "conv2d only support equal stride, and stride must be 4d!";
|
|
|
|
|
}
|
|
|
|
|
if (stride_ori[0] != 1 || stride_ori[1] != 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "conv2d stride only support 1 in N axis and C axis!";
|
|
|
|
|
}
|
|
|
|
|
if (dilation_ori.size() != 4 || dilation_ori[2] != dilation_ori[3]) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "conv2d only support equal dilation, and dilation must be 4d!";
|
|
|
|
|
}
|
|
|
|
|
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "conv2d dilation only support 1 in N axis and C axis!";
|
|
|
|
|
}
|
|
|
|
|
stride_ = stride_ori[2];
|
|
|
|
|
dilation_ = dilation_ori[2];
|
|
|
|
|
}
|
|
|
|
|
cudnnHandle_t cudnn_handle_;
|
|
|
|
|
cudnnTensorDescriptor_t input_desc_;
|
|
|
|
|
cudnnTensorDescriptor_t output_desc_;
|
|
|
|
|