!13369 fix cpu UnsortedSegmentSum

From: @zhao_ting_v
Reviewed-by: @wuxuejian
Signed-off-by: @wuxuejian
pull/13369/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 754e253466

@ -16,7 +16,6 @@
#include "backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.h" #include "backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.h"
#include <string> #include <string>
#include <cmath>
#include "runtime/device/cpu/cpu_device_address.h" #include "runtime/device/cpu/cpu_device_address.h"
#include "common/thread_pool.h" #include "common/thread_pool.h"
@ -79,37 +78,17 @@ bool UnsortedSegmentSumCPUKernel::LaunchKernel(const std::vector<AddressPtr> &in
MS_LOG(ERROR) << "Output buff memset fail. ret:" << ret; MS_LOG(ERROR) << "Output buff memset fail. ret:" << ret;
return false; return false;
} }
size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); for (size_t i = 0; i < unit_num_; ++i) {
if (unit_num_ < thread_num) { size_t j = i / input_dim1_;
thread_num = unit_num_; size_t k = i % input_dim1_;
}
std::vector<common::Task> tasks;
tasks.reserve(thread_num);
auto task = [&](size_t start, size_t end) {
for (size_t c = 0; c < ceil(static_cast<double>(unit_num_) / thread_num); ++c) {
if (c * thread_num + start >= unit_num_) {
continue;
}
size_t i = c * thread_num + start;
size_t j = i / input_dim1_;
size_t k = i % input_dim1_;
T index = indices_addr[j]; T index = indices_addr[j];
if (index < 0 || index >= SizeToInt(output_dim0_)) { if (index < 0 || index >= SizeToInt(output_dim0_)) {
continue; continue;
}
size_t output_index = index * output_dim1_ + k;
output_addr[output_index] += input_addr[i];
} }
}; size_t output_index = index * output_dim1_ + k;
for (size_t t = 0; t < thread_num; ++t) { output_addr[output_index] += input_addr[i];
auto block = [&, t]() {
task(t, t + 1);
return common::SUCCESS;
};
tasks.emplace_back(block);
} }
common::ThreadPool::GetInstance().SyncRun(tasks);
return true; return true;
} }
} // namespace kernel } // namespace kernel

Loading…
Cancel
Save