From f2529129360c9d9f38000acb7d0598bff5c85dee Mon Sep 17 00:00:00 2001 From: zhaoting Date: Tue, 16 Mar 2021 20:04:45 +0800 Subject: [PATCH] fix cpu UnsortedSegmentSum --- .../cpu/unsorted_segment_sum_cpu_kernel.cc | 37 ++++--------------- 1 file changed, 8 insertions(+), 29 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc index d9afeb211c..a6505b0927 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc @@ -16,7 +16,6 @@ #include "backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.h" #include -#include #include "runtime/device/cpu/cpu_device_address.h" #include "common/thread_pool.h" @@ -79,37 +78,17 @@ bool UnsortedSegmentSumCPUKernel::LaunchKernel(const std::vector &in MS_LOG(ERROR) << "Output buff memset fail. ret:" << ret; return false; } - size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); - if (unit_num_ < thread_num) { - thread_num = unit_num_; - } - std::vector tasks; - tasks.reserve(thread_num); - auto task = [&](size_t start, size_t end) { - for (size_t c = 0; c < ceil(static_cast(unit_num_) / thread_num); ++c) { - if (c * thread_num + start >= unit_num_) { - continue; - } - size_t i = c * thread_num + start; - size_t j = i / input_dim1_; - size_t k = i % input_dim1_; + for (size_t i = 0; i < unit_num_; ++i) { + size_t j = i / input_dim1_; + size_t k = i % input_dim1_; - T index = indices_addr[j]; - if (index < 0 || index >= SizeToInt(output_dim0_)) { - continue; - } - size_t output_index = index * output_dim1_ + k; - output_addr[output_index] += input_addr[i]; + T index = indices_addr[j]; + if (index < 0 || index >= SizeToInt(output_dim0_)) { + continue; } - }; - for (size_t t = 0; t < thread_num; ++t) { - auto block = [&, t]() { - task(t, t + 1); - return common::SUCCESS; - }; - tasks.emplace_back(block); + size_t output_index = index * output_dim1_ + k; + output_addr[output_index] += input_addr[i]; } - common::ThreadPool::GetInstance().SyncRun(tasks); return true; } } // namespace kernel