From 9996e0d4d2fe6f7dffc59190d3061216b929e5ea Mon Sep 17 00:00:00 2001
From: VectorSL
Date: Thu, 14 May 2020 12:59:37 +0800
Subject: [PATCH] gpu update shape infer

---
 .../ccsrc/device/gpu/kernel_info_setter.cc    | 13 +++++--
 .../gpu/arrays/array_reduce_gpu_kernel.h      | 39 +++++++------------
 .../kernel/gpu/arrays/slice_gpu_kernel.h      | 12 +-----
 .../kernel/gpu/arrays/slice_grad_gpu_kernel.h | 11 +-----
 mindspore/ccsrc/kernel/gpu/gpu_kernel.h       | 19 ++++++++-
 .../kernel/gpu/math/tensoradd_gpu_kernel.h    | 26 +++++--------
 .../ccsrc/kernel/gpu/nn/pooling_gpu_kernel.h  |  4 +-
 .../kernel/gpu/nn/pooling_grad_gpu_kernel.h   |  4 +-
 .../ccsrc/kernel/gpu/nn/relu_gpu_kernel.h     | 35 ++++++-----------
 .../ccsrc/kernel/gpu/nn/relu_grad_kernel.h    | 24 ++++++------
 10 files changed, 81 insertions(+), 106 deletions(-)

diff --git a/mindspore/ccsrc/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/device/gpu/kernel_info_setter.cc
index 6ccb4c8cde..2ba154b87b 100644
--- a/mindspore/ccsrc/device/gpu/kernel_info_setter.cc
+++ b/mindspore/ccsrc/device/gpu/kernel_info_setter.cc
@@ -184,10 +184,17 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
 
   if (!result) {
     auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-
+    std::string build_type = "in [";
+    std::for_each(std::begin(inputs_type), std::end(inputs_type),
+                  [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
+    build_type += "] out [";
+    std::for_each(std::begin(outputs_type), std::end(outputs_type),
+                  [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
+    build_type += "]";
     auto supported_type_lists = SupportedTypeList(kernel_node);
-    MS_LOG(EXCEPTION) << "Select GPU kernel op[" << kernel_name
-                      << "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists;
+    MS_EXCEPTION(TypeError) << "Select GPU kernel op[" << kernel_name
+                            << "] failed! Incompatible data type.\nThe supported data types are "
+                            << supported_type_lists << ", but got " << build_type;
   }
   builder->SetKernelType(kernel_type);
   builder->SetProcessor(kernel::Processor::CUDA);
diff --git a/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.h
index a12aa17448..c8410c419d 100644
--- a/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.h
@@ -178,46 +178,33 @@ class ArrayReduceGpuKernel : public GpuKernel {
     return;
   }
   void InferInAndOutDesc(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape) {
-    std::vector<size_t> inputA_shape = input_shape;
+    std::vector<int> inputA;
     std::vector<size_t> outputC_shape = output_shape;
-    std::vector<size_t> real_input_shape;
-    int shapeA_n, shapeA_c, shapeA_h, shapeA_w;
-    shapeA_n = inputA_shape.size() < 4 ? 1 : SizeToInt(inputA_shape[inputA_shape.size() - 4]);
-    shapeA_c = inputA_shape.size() < 3 ? 1 : SizeToInt(inputA_shape[inputA_shape.size() - 3]);
-    shapeA_h = inputA_shape.size() < 2 ? 1 : SizeToInt(inputA_shape[inputA_shape.size() - 2]);
-    shapeA_w = inputA_shape.size() == 0 ? 1 : SizeToInt(inputA_shape[inputA_shape.size() - 1]);
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, shapeA_n,
-                                                           shapeA_c, shapeA_h, shapeA_w),
+    ShapeNdTo4d(input_shape, &inputA);
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, inputA[0],
+                                                           inputA[1], inputA[2], inputA[3]),
                                 "cudnnSetTensor4dDescriptor failed");
-    int shapeC_n, shapeC_c, shapeC_h, shapeC_w;
     if (axis_[0] == -1) {
-      shapeC_n = 1;
-      shapeC_c = 1;
-      shapeC_h = 1;
-      shapeC_w = 1;
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
-                                                             shapeC_n, shapeC_c, shapeC_h, shapeC_w),
-                                  "cudnnSetTensor4dDescriptor failed");
-      if (shapeA_n == shapeC_n && shapeA_c == shapeC_c && shapeA_h == shapeC_h && shapeA_w == shapeC_w) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, 1, 1, 1, 1),
+        "cudnnSetTensor4dDescriptor failed");
+      if (inputA[0] == 1 && inputA[1] == 1 && inputA[2] == 1 && inputA[3] == 1) {
         all_match_ = true;
       }
       return;
    }
-
     if (!keep_dims_) {
       for (auto i : axis_) {
         (void)(outputC_shape.insert(outputC_shape.begin() + i, 1));
       }
     }
-    shapeC_n = outputC_shape.size() < 4 ? 1 : SizeToInt(outputC_shape[outputC_shape.size() - 4]);
-    shapeC_c = outputC_shape.size() < 3 ? 1 : SizeToInt(outputC_shape[outputC_shape.size() - 3]);
-    shapeC_h = outputC_shape.size() < 2 ? 1 : SizeToInt(outputC_shape[outputC_shape.size() - 2]);
-    shapeC_w = outputC_shape.size() == 0 ? 1 : SizeToInt(outputC_shape[outputC_shape.size() - 1]);
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, shapeC_n,
-                                                           shapeC_c, shapeC_h, shapeC_w),
+    std::vector<int> outputC;
+    ShapeNdTo4d(outputC_shape, &outputC);
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
+                                                           outputC[0], outputC[1], outputC[2], outputC[3]),
                                 "cudnnSetTensor4dDescriptor failed");
-    if (shapeA_n == shapeC_n && shapeA_c == shapeC_c && shapeA_h == shapeC_h && shapeA_w == shapeC_w) {
+    if (inputA == outputC) {
       all_match_ = true;
     }
     return;
   }
diff --git a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h
index 99b8372008..f71ec23d2e 100644
--- a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h
@@ -52,15 +52,7 @@ class SliceGpuFwdKernel : public GpuKernel {
       return false;
     }
     auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    int shape_n = input_shape.size() < 4 ? 1 : SizeToInt(input_shape[input_shape.size() - 4]);
-    int shape_c = input_shape.size() < 3 ? 1 : SizeToInt(input_shape[input_shape.size() - 3]);
-    int shape_h = input_shape.size() < 2 ? 1 : SizeToInt(input_shape[input_shape.size() - 2]);
-    int shape_w = SizeToInt(input_shape[input_shape.size() - 1]);
-    input_shape_.push_back(shape_n);
-    input_shape_.push_back(shape_c);
-    input_shape_.push_back(shape_h);
-    input_shape_.push_back(shape_w);
-
+    ShapeNdTo4d(input_shape, &input_shape_);
     auto strides = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides");
     if (strides) {
       strides_ = GetAttr<std::vector<int>>(kernel_node, "strides");
@@ -89,7 +81,7 @@ class SliceGpuFwdKernel : public GpuKernel {
       }
     }
 
-    input_size_ = IntToSize(shape_n * shape_c * shape_h * shape_w) * sizeof(T);
+    input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T);
 
     auto out_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
     output_size_ = sizeof(T);
diff --git a/mindspore/ccsrc/kernel/gpu/arrays/slice_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/slice_grad_gpu_kernel.h
index 9660e1dd9b..80eef23112 100644
--- a/mindspore/ccsrc/kernel/gpu/arrays/slice_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/arrays/slice_grad_gpu_kernel.h
@@ -66,19 +66,12 @@ class SliceGradGpuKernel : public GpuKernel {
       size_ = GetAttr<std::vector<int>>(kernel_node, "end");
     } else {
       auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
-      input_shape_.push_back(input_shape.size() < 4 ? 1 : SizeToInt(input_shape[input_shape.size() - 4]));
-      input_shape_.push_back(input_shape.size() < 3 ? 1 : SizeToInt(input_shape[input_shape.size() - 3]));
-      input_shape_.push_back(input_shape.size() < 2 ? 1 : SizeToInt(input_shape[input_shape.size() - 2]));
-      input_shape_.push_back(SizeToInt(input_shape[input_shape.size() - 1]));
+      ShapeNdTo4d(input_shape, &input_shape_);
       size_ = GetAttr<std::vector<int>>(kernel_node, "size");
     }
 
     auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    dy_shape_.push_back(dy_shape.size() < 4 ? 1 : SizeToInt(dy_shape[dy_shape.size() - 4]));
-    dy_shape_.push_back(dy_shape.size() < 3 ? 1 : SizeToInt(dy_shape[dy_shape.size() - 3]));
-    dy_shape_.push_back(dy_shape.size() < 2 ? 1 : SizeToInt(dy_shape[dy_shape.size() - 2]));
-    dy_shape_.push_back(SizeToInt(dy_shape[dy_shape.size() - 1]));
-
+    ShapeNdTo4d(dy_shape, &dy_shape_);
     begin_ = GetAttr<std::vector<int>>(kernel_node, "begin");
     DealParam();
     input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T);
diff --git a/mindspore/ccsrc/kernel/gpu/gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/gpu_kernel.h
index de7176eff7..9f8090451f 100644
--- a/mindspore/ccsrc/kernel/gpu/gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/gpu_kernel.h
@@ -39,7 +39,6 @@ class GpuKernel : public KernelMod {
   virtual void InitSizeLists() = 0;
 
   template <typename T>
-
   inline T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t index) {
     if (index >= addr_list.size()) {
       MS_LOG(EXCEPTION) << "Address index(" << index << ") out of range(" << addr_list.size() << ")";
@@ -62,6 +61,24 @@ class GpuKernel : public KernelMod {
     }
     return GetValue<T>(attr);
   }
+  // expand an Nd shape (N in [0,4]) to 4d by left-padding it with 1s
+  void ShapeNdTo4d(const std::vector<size_t> &src, std::vector<int> *dst) {
+    dst->push_back(src.size() < 4 ? 1 : SizeToInt(src[src.size() - 4]));
+    dst->push_back(src.size() < 3 ? 1 : SizeToInt(src[src.size() - 3]));
+    dst->push_back(src.size() < 2 ? 1 : SizeToInt(src[src.size() - 2]));
+    dst->push_back(src.size() == 0 ? 1 : SizeToInt(src[src.size() - 1]));
+  }
+
+  inline void CheckBroadcast4TensorOp(const std::vector<int> &A, const std::vector<int> &B,
+                                      const std::vector<int> &Out) {
+    if (A != Out && B != Out) {
+      MS_EXCEPTION(ValueError)
+        << "Double-sided broadcast is not supported by cudnnOpTensor:\n"
+           "each dimension of inputA must match the corresponding dimension of the destination tensor outC, "
+           "and each dimension of inputB must match "
+           "the corresponding dimension of outC or be equal to 1.";
+    }
+  }
 };
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h
index ec954173f9..52480b8c70 100644
--- a/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h
@@ -87,32 +87,24 @@ class TensorAddGpuFwdKernel : public GpuKernel {
     auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
     auto input_shapeB = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
     auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-    if (input_shape != output_shape && input_shapeB != output_shape) {
-      MS_LOG(ERROR) << "Double-sided broadcast was not supported in cudnn of cudnnOpTensor:\n"
-                       "InputA must match the corresponding dimension of the destination tensor outC, and each "
-                       "dimension of the inputB "
-                       "must match the corresponding dimension of outC or must be equal to 1.";
-      return false;
-    }
     is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_shapeB);
     if (is_null_input_) {
       MS_LOG(WARNING) << "TensorAddGpuFwdKernel input is null";
       InitSizeLists();
       return true;
     }
-    int shape_n = input_shape.size() < 4 ? 1 : SizeToInt(input_shape[input_shape.size() - 4]);
-    int shape_c = input_shape.size() < 3 ? 1 : SizeToInt(input_shape[input_shape.size() - 3]);
-    int shape_h = input_shape.size() < 2 ? 1 : SizeToInt(input_shape[input_shape.size() - 2]);
-    int shape_w = input_shape.size() == 0 ? 1 : SizeToInt(input_shape[input_shape.size() - 1]);
+    std::vector<int> shapeA;
+    std::vector<int> shapeB;
+    std::vector<int> shapeOut;
+    ShapeNdTo4d(input_shape, &shapeA);
+    ShapeNdTo4d(input_shapeB, &shapeB);
+    ShapeNdTo4d(output_shape, &shapeOut);
+    CheckBroadcast4TensorOp(shapeA, shapeB, shapeOut);
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           shape_n, shape_c, shape_h, shape_w),
+                                                           shapeA[0], shapeA[1], shapeA[2], shapeA[3]),
                                 "cudnnSetTensor4dDescriptor failed");
-    int shapeB_n = input_shapeB.size() < 4 ? 1 : SizeToInt(input_shapeB[input_shapeB.size() - 4]);
-    int shapeB_c = input_shapeB.size() < 3 ? 1 : SizeToInt(input_shapeB[input_shapeB.size() - 3]);
-    int shapeB_h = input_shapeB.size() < 2 ? 1 : SizeToInt(input_shapeB[input_shapeB.size() - 2]);
-    int shapeB_w = input_shapeB.size() == 0 ? 1 : SizeToInt(input_shapeB[input_shapeB.size() - 1]);
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputB_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           shapeB_n, shapeB_c, shapeB_h, shapeB_w),
+                                                           shapeB[0], shapeB[1], shapeB[2], shapeB[3]),
                                 "cudnnSetTensor4dDescriptor failed");
 
     CHECK_CUDNN_RET_WITH_EXCEPT(
diff --git a/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.h
index 67a705cdc1..2446c22950 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.h
@@ -107,8 +107,8 @@ class PoolingGpuFwdKernel : public GpuKernel {
                                   SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])),
                                 "cudnnSetTensor4dDescriptor failed");
     auto window = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize"));
-    int window_height = window[3];
-    int window_width = window[2];
+    int window_height = window[2];
+    int window_width = window[3];
     stride_ = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides"));
     SetPoolingMode(kernel_node);
     if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
diff --git a/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.h
index b2dc7d5e67..535f96bbbf 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.h
@@ -101,8 +101,8 @@ class PoolingGradGpuFwdKernel : public GpuKernel {
       return false;
     }
     auto window = GetAttr<std::vector<int>>(kernel_node, "ksize");
-    int window_height = window[3];
-    int window_width = window[2];
+    int window_height = window[2];
+    int window_width = window[3];
     SetPoolingMode(kernel_node);
     auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
     auto input_mask = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
diff --git a/mindspore/ccsrc/kernel/gpu/nn/relu_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/relu_gpu_kernel.h
index 7931dbaf1b..d88efd3c7a 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/relu_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/nn/relu_gpu_kernel.h
@@ -31,8 +31,7 @@ class ReLUGpuFwdKernel : public GpuKernel {
       : cudnn_handle_(nullptr),
         activation_desc_(nullptr),
         mode_(CUDNN_ACTIVATION_RELU),
-        input_descriptor_(nullptr),
-        output_descriptor_(nullptr),
+        data_descriptor_(nullptr),
         is_null_input_(false),
         cudnn_data_type_(CUDNN_DATA_FLOAT),
         input_size_(0),
@@ -53,8 +52,8 @@ class ReLUGpuFwdKernel : public GpuKernel {
 
     const float alpha = 1;
     const float beta = 0;
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, input_descriptor_,
-                                                       input, &beta, output_descriptor_, output),
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input,
+                                                       &beta, data_descriptor_, output),
                                 "ReLUGpuFwdKernel failed");
 
     return true;
@@ -75,18 +74,12 @@ class ReLUGpuFwdKernel : public GpuKernel {
       return true;
     }
     mode_ = CUDNN_ACTIVATION_RELU;
-    int batch_size = input_shape.size() < 4 ? 1 : SizeToInt(input_shape[input_shape.size() - 4]);
-    int channel_size = input_shape.size() < 3 ? 1 : SizeToInt(input_shape[input_shape.size() - 3]);
-    int height = input_shape.size() < 2 ? 1 : SizeToInt(input_shape[input_shape.size() - 2]);
-    int width = input_shape.size() == 0 ? 1 : SizeToInt(input_shape[input_shape.size() - 1]);
-
+    std::vector<int> shape;
+    ShapeNdTo4d(input_shape, &shape);
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0),
                                 "SetActivationDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           batch_size, channel_size, height, width),
-                                "SetTensor4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           batch_size, channel_size, height, width),
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
+                                                           shape[0], shape[1], shape[2], shape[3]),
                                 "SetTensor4dDescriptor failed");
     InitSizeLists();
     return true;
@@ -95,18 +88,16 @@ class ReLUGpuFwdKernel : public GpuKernel {
  protected:
   void InitResource() override {
     cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_descriptor_), "cudnnCreateTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&data_descriptor_), "cudnnCreateTensorDescriptor failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateActivationDescriptor(&activation_desc_),
                                 "cudnnCreateActivationDescriptor failed");
   }
   void InitSizeLists() override {
     if (!is_null_input_) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_descriptor_, &input_size_),
-                                  "cudnnGetTensorSizeInBytes failed");
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(output_descriptor_, &output_size_),
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(data_descriptor_, &input_size_),
                                   "cudnnGetTensorSizeInBytes failed");
+      output_size_ = input_size_;
     }
     input_size_list_.push_back(input_size_);
     output_size_list_.push_back(output_size_);
   }
@@ -116,15 +107,13 @@ class ReLUGpuFwdKernel : public GpuKernel {
   void DestroyResource() noexcept {
     CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_),
                                "cudnnDestroyActivationDescriptor failed");
-    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_descriptor_), "cudnnDestroyTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed");
   }
 
   cudnnHandle_t cudnn_handle_;
   cudnnActivationDescriptor_t activation_desc_;
   cudnnActivationMode_t mode_;
-  cudnnTensorDescriptor_t input_descriptor_;
-  cudnnTensorDescriptor_t output_descriptor_;
+  cudnnTensorDescriptor_t data_descriptor_;
   bool is_null_input_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
diff --git a/mindspore/ccsrc/kernel/gpu/nn/relu_grad_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/relu_grad_kernel.h
index 713c08c654..e93dc31f80 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/relu_grad_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/nn/relu_grad_kernel.h
@@ -31,7 +31,7 @@ class ReluGradGpuFwdKernel : public GpuKernel {
       : cudnn_handle_(nullptr),
         activation_desc_(nullptr),
         mode_(CUDNN_ACTIVATION_RELU),
-        input_descriptor_(nullptr),
+        data_descriptor_(nullptr),
         is_null_input_(false),
         cudnn_data_type_(CUDNN_DATA_FLOAT),
         input_size_(0) {}
@@ -52,8 +52,8 @@ class ReluGradGpuFwdKernel : public GpuKernel {
     const float alpha = 1;
     const float beta = 0;
     CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, input_descriptor_, y, input_descriptor_, dy,
-                              input_descriptor_, y, &beta, input_descriptor_, dx),
+      cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
+                              data_descriptor_, y, &beta, data_descriptor_, dx),
       "cudnnActivationBackward failed");
 
     return true;
@@ -74,14 +74,12 @@ class ReluGradGpuFwdKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
-    int batch_size = input_shape.size() < 4 ? 1 : SizeToInt(input_shape[input_shape.size() - 4]);
-    int channel_size = input_shape.size() < 3 ? 1 : SizeToInt(input_shape[input_shape.size() - 3]);
-    int height = input_shape.size() < 2 ? 1 : SizeToInt(input_shape[input_shape.size() - 2]);
-    int width = input_shape.size() == 0 ? 1 : SizeToInt(input_shape[input_shape.size() - 1]);
+    std::vector<int> shape;
+    ShapeNdTo4d(input_shape, &shape);
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, 0.0),
                                 "SetActivationDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           batch_size, channel_size, height, width),
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
+                                                           shape[0], shape[1], shape[2], shape[3]),
                                 "SetTensor4dDescriptor failed");
 
     InitSizeLists();
@@ -91,13 +89,13 @@ class ReluGradGpuFwdKernel : public GpuKernel {
  protected:
   void InitResource() override {
     cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&data_descriptor_), "cudnnCreateTensorDescriptor failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateActivationDescriptor(&activation_desc_),
                                 "cudnnCreateActivationDescriptor failed");
   }
   void InitSizeLists() override {
     if (!is_null_input_) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_descriptor_, &input_size_),
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(data_descriptor_, &input_size_),
                                   "cudnnGetTensorSizeInBytes failed");
     }
     input_size_list_.push_back(input_size_);
@@ -109,13 +107,13 @@ class ReluGradGpuFwdKernel : public GpuKernel {
   void DestroyResource() noexcept {
     CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_),
                                "cudnnDestroyActivationDescriptor failed");
-    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed");
   }
 
   cudnnHandle_t cudnn_handle_;
   cudnnActivationDescriptor_t activation_desc_;
   cudnnActivationMode_t mode_;
-  cudnnTensorDescriptor_t input_descriptor_;
+  cudnnTensorDescriptor_t data_descriptor_;
   bool is_null_input_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
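
Reviewer note: the new ShapeNdTo4d helper left-pads a rank-N shape (N in [0,4])
with 1s so every kernel can hand cudnn a full NCHW descriptor. Below is a
minimal self-contained sketch of that padding rule; it is not the patch's exact
code (the helper in the patch writes through an out-parameter and uses
MindSpore's SizeToInt, which is replaced here by a plain static_cast so the
sketch compiles on its own):

    #include <cassert>
    #include <vector>

    // Left-pad an Nd shape (N in [0,4]) with 1s to produce NCHW.
    std::vector<int> ShapeNdTo4d(const std::vector<size_t> &src) {
      std::vector<int> dst;
      dst.push_back(src.size() < 4 ? 1 : static_cast<int>(src[src.size() - 4]));  // N
      dst.push_back(src.size() < 3 ? 1 : static_cast<int>(src[src.size() - 3]));  // C
      dst.push_back(src.size() < 2 ? 1 : static_cast<int>(src[src.size() - 2]));  // H
      dst.push_back(src.empty() ? 1 : static_cast<int>(src[src.size() - 1]));     // W
      return dst;
    }

    int main() {
      // A 2d (H, W) = (5, 3) tensor becomes the NCHW shape (1, 1, 5, 3).
      assert((ShapeNdTo4d({5, 3}) == std::vector<int>{1, 1, 5, 3}));
      // A scalar (rank 0) becomes (1, 1, 1, 1).
      assert((ShapeNdTo4d({}) == std::vector<int>{1, 1, 1, 1}));
      return 0;
    }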
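Reviewer note: CheckBroadcast4TensorOp encodes the cudnnOpTensor restriction
that broadcasting is one-sided: after padding to 4d, at least one operand's
shape must already equal the output shape, and only the other operand may carry
1s in broadcast dimensions. Worked example with hypothetical shapes: A = {2, 3,
4, 5}, B = {1, 3, 1, 5}, Out = {2, 3, 4, 5} passes because A == Out and B only
broadcasts over N and H; A = {2, 1, 4, 5}, B = {1, 3, 4, 5}, Out = {2, 3, 4, 5}
raises the new ValueError because both operands would have to broadcast. This
also replaces the old MS_LOG(ERROR)-and-return-false path in TensorAdd, so the
mismatch now surfaces as an exception at kernel-init time.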