diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h
index cc6293e9bc..1ee420c075 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h
@@ -216,7 +216,7 @@ class ArrayReduceGpuKernel : public GpuKernel {
     std::vector<size_t> inputA;
     std::vector<size_t> outputC_shape = output_shape;
     const int split_dim = 4;
-
+    CHECK_TENSOR_SIZE(input_shape);
     if (input_shape.size() <= split_dim) {
       ShapeNdTo4d(input_shape, &inputA);
       CHECK_CUDNN_RET_WITH_EXCEPT(
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
index 64749262b4..7c66779274 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h
@@ -76,6 +76,7 @@ class ActivationGpuFwdKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
+    CHECK_TENSOR_SIZE(input_shape);
     std::vector<size_t> shape;
     double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 6.0 : 0.0;
     if (mode_ == CUDNN_ACTIVATION_ELU) {
@@ -85,7 +86,6 @@
     CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
                                 cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, coef),
                                 "cudnnSetActivationDescriptor failed");
-
     const int split_dim = 4;
     if (input_shape.size() <= split_dim) {
       ShapeNdTo4d(input_shape, &shape);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
index 30fabdb452..7af05b9642 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
@@ -84,6 +84,7 @@ class ActivationGradGpuKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
+    CHECK_TENSOR_SIZE(input_shape);
     std::vector<size_t> shape;
     double coef = (mode_ == CUDNN_ACTIVATION_CLIPPED_RELU) ? 5.999999 : 0.0;
     if (mode_ == CUDNN_ACTIVATION_ELU) coef = 1.0;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h
index 8e499ff3dc..f6c3205df6 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h
@@ -132,6 +132,7 @@ class BatchNormGpuKernel : public GpuKernel {
     if (format_attr == kOpFormat_NHWC) {
       format = kOpFormat_NHWC;
     }
+    CHECK_TENSOR_SIZE(shape);
     SetTensorDescriptor(format, shape);
     InitSizeLists();
     return true;
@@ -254,7 +255,6 @@
       width = SizeToInt(shape[3]);
       cudnn_format = CUDNN_TENSOR_NCHW;
     }
-
     CHECK_CUDNN_RET_WITH_EXCEPT(
       kernel_node_, cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
       "Set x desc failed");
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h
index 9754414e55..8c0a4ca7d4 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h
@@ -155,6 +155,7 @@ class BatchNormGradGpuKernel : public GpuKernel {
       format = kOpFormat_NHWC;
     }
     beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
+    CHECK_TENSOR_SIZE(shape);
     SetTensorDescriptor(format, shape);
     InitSizeLists();
     is_train_ = GetAttr<bool>(kernel_node, "is_training");
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h
index f3db6d7658..b7359266ff 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h
@@ -99,6 +99,7 @@ class Conv2dGpuFwdKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
+    CHECK_TENSOR_SIZE(in_shape);
     SetNCHW(in_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
     if (data_format_ == kOpFormat_NHWC) {
       compute_format_ = CUDNN_TENSOR_NHWC;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h
index caa692f80c..40615c472b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h
@@ -118,6 +118,7 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
+    CHECK_TENSOR_SIZE(in_shape);
     data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
     format_attr_ = GetAttr<std::string>(kernel_node, "format");
     if (format_attr_ == kOpFormat_NHWC) {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
index 795154c174..fd03492d62 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
@@ -133,6 +133,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
         ShapeNCHW2NHWC(&input_shape);
       }
     }
+    CHECK_TENSOR_SIZE(input_shape);
     SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);

     Set4DDesc(dy_shape, input_shape, filter_shape);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h
index 95fc16cf95..80ed403732 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h
@@ -98,6 +98,7 @@ class Im2ColGpuFwdKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
+    CHECK_TENSOR_SIZE(in_shape);
     Set4DDesc(in_shape, filter_shape, output_shape);
     CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetConvolutionGroupCount(conv_desc_, 1),
                                 "cudnnSetConvGroupCount failed");
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h
index 98883ac447..dad2ea6d6b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h
@@ -134,6 +134,7 @@ class InstanceNormGpuKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
+    CHECK_TENSOR_SIZE(input_shape_);
     SetTensorDescriptor();
     InitSizeLists();
     return true;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h
index bb4d019ab8..bac4b53153 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h
@@ -131,6 +131,7 @@ class InstanceNormGradGpuKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
+    CHECK_TENSOR_SIZE(input_shape_);
     beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
     SetTensorDescriptor();
     InitSizeLists();
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_gpu_kernel.h
index 78ec17b743..311b20fd1d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_gpu_kernel.h
@@ -110,6 +110,7 @@ class L2NormalizeGpuKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
+    CHECK_TENSOR_SIZE(inputA_shape);
     if (inputA_shape.size() > MAX_DIMS) {
       MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";
     }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_grad_gpu_kernel.h
index 451303d79b..11fa9fb86d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/l2normalize_grad_gpu_kernel.h
@@ -262,7 +262,8 @@ class L2NormalizeGradGpuKernel : public GpuKernel {
     std::vector<size_t> inputA;
     std::vector<size_t> outputC_shape = output_shape;
     constexpr int split_dim = 4;
-
+    CHECK_TENSOR_SIZE(input_shape);
+    CHECK_TENSOR_SIZE(output_shape);
     if (input_shape.size() <= split_dim) {
       ShapeNdTo4d(input_shape, &inputA);
       CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h
index 9e5060f951..dd46b9f7df 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h
@@ -93,6 +93,7 @@ class PoolingGpuFwdKernel : public GpuKernel {
       InitSizeLists();
       return true;
     }
+    CHECK_TENSOR_SIZE(input_shape);
     SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
     const int nbDims = 4;
     int dimA[4];
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h
index e385d74c2e..9d8f018251 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h
@@ -96,6 +96,7 @@ class PoolingGradGpuKernel : public GpuKernel {
       InitSizeLists();
       return false;
     }
+    CHECK_TENSOR_SIZE(input_shape);
     SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format);
     SetDimA(input_shape, dimA, 4, data_format);
     SetStrideA(input_shape, strideAin, 4, data_format);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h
index 680f34f30e..c30aa45236 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h
@@ -136,6 +136,7 @@ class BatchNormFoldGpuKernel : public GpuKernel {
                     << ", but BatchNormFold GpuKernel OP needs 4DTensor input.";
       return false;
     }
+    CHECK_TENSOR_SIZE(input_shape);
     batch_ = input_shape[0];
     channel_ = input_shape[1];
     height_ = input_shape[2];
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_common.h b/mindspore/ccsrc/runtime/device/gpu/gpu_common.h
index d72485be1b..6631f834ac 100644
--- a/mindspore/ccsrc/runtime/device/gpu/gpu_common.h
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_common.h
@@ -189,7 +189,7 @@ namespace gpu {
 #define VARIABLE_NOT_USED(var) \
   { (void)(var); }

-inline bool CheckNullInput(std::vector<size_t> input_shape) {
+inline bool CheckNullInput(const std::vector<size_t> &input_shape) {
   // If input_shape.size() == 0, it means a scalar input; If input_shape.size() != 0 and input_shape contains 0,
   // it means a null input. Just return a null output.
   if (input_shape.size() != 0) {
@@ -201,6 +201,19 @@ inline bool CheckNullInput(std::vector<size_t> input_shape) {
 }
 #define CHECK_NULL_INPUT(input_shape) mindspore::device::gpu::CheckNullInput(input_shape)

+// The tensor size is limited to 2G by cudnn.
+inline void CheckTensorSize(const std::vector<size_t> &shape) {
+  size_t total_size = 1;
+  for (auto i : shape) {
+    total_size *= i;
+  }
+  if (total_size >= 2147483648) {
+    MS_EXCEPTION(ValueError) << "The total size of the tensor exceeds the max_limit of 2 Giga-elements, which is "
+                             << total_size << " elements (" << shape << ").";
+  }
+}
+#define CHECK_TENSOR_SIZE(shape) mindspore::device::gpu::CheckTensorSize(shape)
+
 #define CHECK_CURAND_RET_WITH_EXCEPT(expression, message) \
   { \
     curandStatus_t status = (expression); \
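
As a note on the new helper: below is a minimal, standalone sketch of the 2G-element guard that the gpu_common.h hunk introduces. It is an illustration only, not part of the patch: MS_EXCEPTION(ValueError) is swapped for std::runtime_error so the snippet compiles outside MindSpore, and the shapes passed in main() are made up.

// Standalone sketch of the CheckTensorSize guard (assumptions: plain C++
// exception instead of MS_EXCEPTION, hypothetical shapes).
#include <cstddef>
#include <sstream>
#include <stdexcept>
#include <vector>

inline void CheckTensorSize(const std::vector<size_t> &shape) {
  size_t total_size = 1;
  for (auto i : shape) {
    total_size *= i;  // element count, not bytes
  }
  if (total_size >= (size_t(1) << 31)) {  // 2147483648, as in the patch
    std::ostringstream oss;
    oss << "The total size of the tensor exceeds the max_limit of 2 Giga-elements, which is " << total_size
        << " elements.";
    throw std::runtime_error(oss.str());
  }
}

int main() {
  CheckTensorSize({32, 3, 224, 224});  // ~4.8M elements: passes silently
  try {
    CheckTensorSize({1024, 1024, 1024, 4});  // 2^32 elements: rejected
  } catch (const std::runtime_error &) {
    // expected: the guard fires before any cuDNN descriptor is created
  }
  return 0;
}

The guard counts elements rather than bytes, matching the patch comment that the tensor size is limited to 2G by cudnn. Each kernel above invokes CHECK_TENSOR_SIZE while parsing its input shapes, so oversized tensors are rejected at kernel initialization rather than at launch time.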