diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h index 80dd18a8cb..32527af5c6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h @@ -27,7 +27,8 @@ namespace kernel { template class UnsortedSegmentSumGpuKernel : public GpuKernel { public: - UnsortedSegmentSumGpuKernel() : input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1) {} + UnsortedSegmentSumGpuKernel() + : input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1), is_null_input_(false) {} ~UnsortedSegmentSumGpuKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } @@ -36,6 +37,9 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel { bool Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } T *input_addr = GetDeviceAddress(inputs, 0); S *indices_addr = GetDeviceAddress(inputs, 1); T *output_addr = GetDeviceAddress(outputs, 0); @@ -50,6 +54,12 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel { bool Init(const CNodePtr &kernel_node) override { auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shapes); + if (is_null_input_) { + MS_LOG(WARNING) << "UnsortedSegmentSum input is null"; + InitSizeLists(); + return true; + } auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); @@ -83,6 +93,7 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel { size_t input_dim1_; size_t output_dim0_; size_t output_dim1_; + bool is_null_input_; std::vector input_size_list_; std::vector output_size_list_;