From 098d588d7df17354e9a856fa8f106cff55999c68 Mon Sep 17 00:00:00 2001 From: danishnxt Date: Thu, 17 Dec 2020 15:53:45 -0500 Subject: [PATCH] UnsortedSegMin/Max output_shape validation fix --- .../gpu/arrays/unsorted_segment_max_gpu_kernel.h | 8 ++++---- .../gpu/arrays/unsorted_segment_min_gpu_kernel.h | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h index f08838ec71..4c651e3fb6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h @@ -71,10 +71,10 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel { } else { MS_LOG(INFO) << "UnsortedSegmentMax Kernel Input count is 2"; } - auto value_count = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0); - if (value_count.size() != 1) { - MS_LOG(ERROR) << "For UnsortedSegmentMax, output shape incorrect rank. Expect Rank: 1, got Rank: " - << value_count.size() << "."; + if (output_shapes.size() < 1) { + MS_LOG(EXCEPTION) + << "For UnsortedSegmentMax, output shape incorrect rank. Expect Rank at least rank 1, got Rank: " + << output_shapes.size() << "."; } num_segments_ = output_shapes[0]; input_size_ = 1; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.h index 6eebc557e2..43cc8ce017 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.h @@ -65,10 +65,10 @@ class UnsortedSegmentMinGpuKernel : public GpuKernel { } else { MS_LOG(INFO) << "UnsortedSegmentMin Kernel Input count is 2"; } - auto value_count = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0); - if (value_count.size() != 1) { - MS_LOG(ERROR) << "For UnsortedSegmentMin, output shape incorrect rank. Expect Rank: 1, got Rank: " - << value_count.size() << "."; + if (output_shapes.size() < 1) { + MS_LOG(EXCEPTION) + << "For UnsortedSegmentMin, output shape incorrect rank. Expect Rank at least rank 1, got Rank: " + << output_shapes.size() << "."; } num_segments_ = output_shapes[0]; input_size_ = 1;