!13369 fix cpu UnsortedSegmentSum

From: @zhao_ting_v
Reviewed-by: @wuxuejian
Signed-off-by: @wuxuejian
pull/13369/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 754e253466

@ -16,7 +16,6 @@
#include "backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.h"
#include <string>
#include <cmath>
#include "runtime/device/cpu/cpu_device_address.h"
#include "common/thread_pool.h"
@ -79,18 +78,7 @@ bool UnsortedSegmentSumCPUKernel::LaunchKernel(const std::vector<AddressPtr> &in
MS_LOG(ERROR) << "Output buff memset fail. ret:" << ret;
return false;
}
size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
if (unit_num_ < thread_num) {
thread_num = unit_num_;
}
std::vector<common::Task> tasks;
tasks.reserve(thread_num);
auto task = [&](size_t start, size_t end) {
for (size_t c = 0; c < ceil(static_cast<double>(unit_num_) / thread_num); ++c) {
if (c * thread_num + start >= unit_num_) {
continue;
}
size_t i = c * thread_num + start;
for (size_t i = 0; i < unit_num_; ++i) {
size_t j = i / input_dim1_;
size_t k = i % input_dim1_;
@ -101,15 +89,6 @@ bool UnsortedSegmentSumCPUKernel::LaunchKernel(const std::vector<AddressPtr> &in
size_t output_index = index * output_dim1_ + k;
output_addr[output_index] += input_addr[i];
}
};
for (size_t t = 0; t < thread_num; ++t) {
auto block = [&, t]() {
task(t, t + 1);
return common::SUCCESS;
};
tasks.emplace_back(block);
}
common::ThreadPool::GetInstance().SyncRun(tasks);
return true;
}
} // namespace kernel

Loading…
Cancel
Save