diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc index 238c2df02a..bc4df65f4a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.cc @@ -30,7 +30,14 @@ MS_REG_GPU_KERNEL_ONE( UnsortedSegmentMax, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), UnsortedSegmentMaxGpuKernel, int) -// Dynamic Mode +// Dynamic Mode - registered for int32/int64 3rd input +MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + UnsortedSegmentMaxGpuKernel, float) MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, KernelAttr() .AddInputAttr(kNumberTypeFloat32) @@ -38,6 +45,13 @@ MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat32), UnsortedSegmentMaxGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + UnsortedSegmentMaxGpuKernel, half) MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, KernelAttr() .AddInputAttr(kNumberTypeFloat16) @@ -45,6 +59,13 @@ MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat16), UnsortedSegmentMaxGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + UnsortedSegmentMaxGpuKernel, int) MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax, KernelAttr() .AddInputAttr(kNumberTypeInt32) diff --git 
a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h index 9350ebca81..2c6fa1a177 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h @@ -69,7 +69,11 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel { } else { MS_LOG(INFO) << "UnsortedSegmentMax Kernel Input count is 2"; } - + auto value_count = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0); + if (value_count.size() != 1) { + MS_LOG(ERROR) << "For UnsortedSegmentMax, output shape incorrect rank. Expect Rank: 1, got Rank: " + << value_count.size() << "."; + } num_segments_ = output_shapes[0]; input_size_ = 1; for (size_t i = 0; i < input_shapes.size(); i++) { @@ -117,7 +121,7 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel { } private: - int num_segments_; + int64_t num_segments_; size_t inner_size_; size_t outer_size_; size_t input_size_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.cc index c067a7c50e..4d35565174 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.cc @@ -30,7 +30,14 @@ MS_REG_GPU_KERNEL_ONE( UnsortedSegmentMin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), UnsortedSegmentMinGpuKernel, int) -// Dynamic Mode +// Dynamic Mode - registered for int32/int64 3rd input +MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + UnsortedSegmentMinGpuKernel, 
float) MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, KernelAttr() .AddInputAttr(kNumberTypeFloat32) @@ -38,6 +45,13 @@ MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat32), UnsortedSegmentMinGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + UnsortedSegmentMinGpuKernel, half) MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, KernelAttr() .AddInputAttr(kNumberTypeFloat16) @@ -45,6 +59,13 @@ MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat16), UnsortedSegmentMinGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + UnsortedSegmentMinGpuKernel, int) MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMin, KernelAttr() .AddInputAttr(kNumberTypeInt32) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.h index 0d7b4a6abb..6eebc557e2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.h @@ -65,7 +65,11 @@ class UnsortedSegmentMinGpuKernel : public GpuKernel { } else { MS_LOG(INFO) << "UnsortedSegmentMin Kernel Input count is 2"; } - + auto value_count = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0); + if (value_count.size() != 1) { + MS_LOG(ERROR) << "For UnsortedSegmentMin, output shape incorrect rank. 
Expect Rank: 1, got Rank: " + << value_count.size() << "."; + } num_segments_ = output_shapes[0]; input_size_ = 1; for (size_t i = 0; i < input_shapes.size(); i++) { @@ -113,7 +117,7 @@ class UnsortedSegmentMinGpuKernel : public GpuKernel { } private: - int num_segments_; + int64_t num_segments_; size_t inner_size_; size_t outer_size_; size_t input_size_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cu index 954611ffe7..c1eb49f00b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cu @@ -18,8 +18,8 @@ #include template -__global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, - size_t inner_size, bool fp16_flag, T init_K, T *output) { +__global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments, + size_t outer_size, size_t inner_size, bool fp16_flag, T init_K, T *output) { if (fp16_flag) { init_K = __int2half_rd(-65504); // min value representable by float16 } @@ -57,7 +57,7 @@ __global__ void UnsortedSegmentMax(const T *input, const int *segment_ids, const } template -void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, +void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size, size_t inner_size, T *output, cudaStream_t stream) { int size = (inner_size * KWARPSIZE * num_segments); bool fp16_flag = false; @@ -71,9 +71,9 @@ void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int num return; } -template void CalUnsortedSegmentMax(const float *input, const int *segment_ids, const int num_segments, +template void CalUnsortedSegmentMax(const float *input, const int *segment_ids, const int64_t 
num_segments, size_t outer_size, size_t inner_size, float *output, cudaStream_t stream); -template void CalUnsortedSegmentMax(const half *input, const int *segment_ids, const int num_segments, +template void CalUnsortedSegmentMax(const half *input, const int *segment_ids, const int64_t num_segments, size_t outer_size, size_t inner_size, half *output, cudaStream_t stream); -template void CalUnsortedSegmentMax(const int *input, const int *segment_ids, const int num_segments, +template void CalUnsortedSegmentMax(const int *input, const int *segment_ids, const int64_t num_segments, size_t outer_size, size_t inner_size, int *output, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh index de80b74007..caab13ce65 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_max.cuh @@ -22,9 +22,8 @@ // Setting warp size to sync data across threads #define KWARPSIZE 32 - template -void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, +void CalUnsortedSegmentMax(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size, size_t inner_size, T *output, cudaStream_t stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_MAX_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_min.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_min.cu index 7e3de1a84c..fe8958cb18 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_min.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_min.cu @@ -17,19 +17,19 @@ #include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_min.cuh" #include -template +template __device__ 
__forceinline__ void max_val_init(T *init_val) { *init_val = std::numeric_limits::max(); } // Handle fp16 differently for assignment -template<> +template <> __device__ __forceinline__ void max_val_init(half *init_val) { *init_val = __int2half_rd(65504); // Max value for Half } template -__global__ void UnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, - size_t inner_size, T init_K, T *output) { +__global__ void UnsortedSegmentMin(const T *input, const int *segment_ids, const int64_t num_segments, + size_t outer_size, size_t inner_size, T init_K, T *output) { max_val_init(&init_K); for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < KWARPSIZE * num_segments * inner_size; t_idx += blockDim.x * gridDim.x) { @@ -62,18 +62,18 @@ __global__ void UnsortedSegmentMin(const T *input, const int *segment_ids, const } template -void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, +void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size, size_t inner_size, T *output, cudaStream_t stream) { int size = (inner_size * KWARPSIZE * num_segments); - T init_K = std::numeric_limits::lowest(); // only init here - overwritten later + T init_K = std::numeric_limits::lowest(); // only init here - overwritten later UnsortedSegmentMin<<>>(input, segment_ids, num_segments, outer_size, inner_size, init_K, output); return; } -template void CalUnsortedSegmentMin(const float *input, const int *segment_ids, const int num_segments, +template void CalUnsortedSegmentMin(const float *input, const int *segment_ids, const int64_t num_segments, size_t outer_size, size_t inner_size, float *output, cudaStream_t stream); -template void CalUnsortedSegmentMin(const half *input, const int *segment_ids, const int num_segments, +template void CalUnsortedSegmentMin(const half *input, const int *segment_ids, const int64_t num_segments, 
size_t outer_size, size_t inner_size, half *output, cudaStream_t stream); -template void CalUnsortedSegmentMin(const int *input, const int *segment_ids, const int num_segments, +template void CalUnsortedSegmentMin(const int *input, const int *segment_ids, const int64_t num_segments, size_t outer_size, size_t inner_size, int *output, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_min.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_min.cuh index 01b1836b67..45d5de6dff 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_min.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_min.cuh @@ -23,6 +23,6 @@ // Setting warp size to sync data across threads #define KWARPSIZE 32 template -void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int num_segments, size_t outer_size, +void CalUnsortedSegmentMin(const T *input, const int *segment_ids, const int64_t num_segments, size_t outer_size, size_t inner_size, T *output, cudaStream_t stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_MIN_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h index 106744d2e0..169bbd7e3f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pad_gpu_kernel.h @@ -29,7 +29,7 @@ namespace kernel { template class PadGpuFwdKernel : public GpuKernel { public: - PadGpuFwdKernel() : shape_size_(0), temp(0), input_size_(0), output_size_(0), workspace_size_(0) {} + PadGpuFwdKernel() : shape_size_(0), temp(0), input_size_(1), output_size_(1), workspace_size_(0) {} ~PadGpuFwdKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } @@ -53,13 +53,11 @@ class PadGpuFwdKernel : public GpuKernel { } bool 
Init(const CNodePtr &kernel_node) override { - // check number of inputs -> should be 1 size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 1) { MS_LOG(ERROR) << "Input number is " << input_num << ", but Pad needs 1 input."; return false; } - // check number of output -> should be 1 size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); if (output_num != 1) { MS_LOG(ERROR) << "Output number is " << output_num << ", but Pad needs 1 output."; @@ -67,8 +65,7 @@ class PadGpuFwdKernel : public GpuKernel { } auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); shape_size_ = input_shape.size(); - // shape adjustement -> from 2d/3d to 4d to standardize - if (shape_size_ == 4) { + if (shape_size_ == 4) { // shape adjustment from 2d/3d to 4d } else if (shape_size_ == 3) { auto it = input_shape.begin(); input_shape.insert(it, 1); // batch padding @@ -87,8 +84,7 @@ class PadGpuFwdKernel : public GpuKernel { [](const int64_t &value) { return static_cast(value); }); return shape; }); - // shape adjustement -> from 2d/3d to 4d to standardize - if (paddings.size() == 4) { + if (paddings.size() == 4) { // shape adjustment from 2d/3d to 4d } else if (paddings.size() == 3) { auto it = paddings.begin(); paddings.insert(it, 1, {0, 0}); // batch padding @@ -96,13 +92,11 @@ class PadGpuFwdKernel : public GpuKernel { auto it = paddings.begin(); paddings.insert(it, 2, {0, 0}); // channel padding } - input_size_ = 1; for (size_t i = 0; i < shape_size_; i++) { input_size_ *= input_shape[i]; input_shape_.push_back(input_shape[i]); } input_size_ *= sizeof(T); - output_size_ = 1; for (size_t i = 0; i < shape_size_; i++) { temp = input_shape[i] + (paddings[i][0] + paddings[i][1]); // compute new dim size output_size_ *= temp; diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index 3a86daaace..45c0c78877 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc 
@@ -227,10 +227,18 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const Pri MS_EXCEPTION_IF_NULL(num_segments_value_ptr); auto num_segments_tensor = num_segments_value_ptr->cast(); MS_EXCEPTION_IF_NULL(num_segments_tensor); - num_segments_value = *static_cast(num_segments_tensor->data_c()); + if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { + num_segments_value = *static_cast(num_segments_tensor->data_c()); + } else { + num_segments_value = *static_cast(num_segments_tensor->data_c()); + } } else if (args_spec_list[2]->isa()) { // num_segments is Scalar auto num_segments = CheckArg(op_name, args_spec_list, 2); - num_segments_value = GetValue(num_segments->BuildValue()); + if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { + num_segments_value = GetValue(num_segments->BuildValue()); + } else { + num_segments_value = GetValue(num_segments->BuildValue()); + } } else { MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentSum"; } @@ -300,10 +308,19 @@ AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const Pri MS_EXCEPTION_IF_NULL(num_segments_value_ptr); auto num_segments_tensor = num_segments_value_ptr->cast(); MS_EXCEPTION_IF_NULL(num_segments_tensor); - num_segments_value = *static_cast(num_segments_tensor->data_c()); + if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { + num_segments_value = *static_cast(num_segments_tensor->data_c()); + } else { + num_segments_value = *static_cast(num_segments_tensor->data_c()); + } + // num_segments_value = *static_cast(num_segments_tensor->data_c()); } else if (args_spec_list[2]->isa()) { // num_segments is Scalar auto num_segments = CheckArg(op_name, args_spec_list, 2); - num_segments_value = GetValue(num_segments->BuildValue()); + if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { + num_segments_value = GetValue(num_segments->BuildValue()); + } 
else { + num_segments_value = GetValue(num_segments->BuildValue()); + } } else { MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentMax"; } @@ -368,10 +385,18 @@ AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const Pri MS_EXCEPTION_IF_NULL(num_segments_value_ptr); auto num_segments_tensor = num_segments_value_ptr->cast(); MS_EXCEPTION_IF_NULL(num_segments_tensor); - num_segments_value = *static_cast(num_segments_tensor->data_c()); + if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { + num_segments_value = *static_cast(num_segments_tensor->data_c()); + } else { + num_segments_value = *static_cast(num_segments_tensor->data_c()); + } } else if (args_spec_list[2]->isa()) { // num_segments is Scalar auto num_segments = CheckArg(op_name, args_spec_list, 2); - num_segments_value = GetValue(num_segments->BuildValue()); + if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) { + num_segments_value = GetValue(num_segments->BuildValue()); + } else { + num_segments_value = GetValue(num_segments->BuildValue()); + } } else { MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentMin"; } diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 245cbae48b..53131174b2 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1893,8 +1893,10 @@ class UnsortedSegmentSum(PrimitiveWithInfer): validator.check_positive_int(segment_ids_shp_len, "rank of segment_ids", self.name) validator.check(f'rank of input_x', len(x_shp), 'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name) - for i, value in enumerate(segment_ids_shp): - validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name) + if (not -1 in x_shp and not -1 in segment_ids_shp): + # only validate when both shapes fully known + for i, value in enumerate(segment_ids_shp): + validator.check("ids[%d]" % i, value, 
'input[%d]' % i, x_shp[i], Rel.EQ, self.name) num_segments_v = num_segments['value'] num_segments_type = num_segments['dtype'] validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) @@ -1968,7 +1970,7 @@ class UnsortedSegmentMin(PrimitiveWithCheck): num_segments_type = num_segments['dtype'] validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) if isinstance(num_segments_type, type(mstype.tensor)): - validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int64], + validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32, mstype.int64], self.name) else: validator.check_value_type('num_segments', num_segments['value'], [int], self.name) @@ -2021,7 +2023,7 @@ class UnsortedSegmentMax(PrimitiveWithCheck): num_segments_type = num_segments['dtype'] validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) if isinstance(num_segments_type, type(mstype.tensor)): - validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int64], + validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32, mstype.int64], self.name) else: validator.check_value_type('num_segments', num_segments['value'], [int], self.name)